forked from mindspore-Ecosystem/mindspore
!38651 [MS][LITE] add op coders for micro train
Merge pull request !38651 from jianghui58/micro_train_dev
This commit is contained in:
commit
202e6c9b6d
|
@ -10,6 +10,10 @@ function(__install_micro_wrapper)
|
|||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(DIRECTORY ${NNACL_DIR}/fp32 DESTINATION ${CODEGEN_ROOT_DIR}/include/nnacl
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(FILES ${NNACL_DIR}/fp32_grad/activation_grad.h DESTINATION
|
||||
${CODEGEN_ROOT_DIR}/include/nnacl/fp32_grad COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${NNACL_DIR}/fp32_grad/softmax_cross_entropy_with_logits.h DESTINATION
|
||||
${CODEGEN_ROOT_DIR}/include/nnacl/fp32_grad COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(DIRECTORY ${NNACL_DIR}/kernel DESTINATION ${CODEGEN_ROOT_DIR}/include/nnacl
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(DIRECTORY ${NNACL_DIR}/infer DESTINATION ${CODEGEN_ROOT_DIR}/include/nnacl
|
||||
|
|
|
@ -209,3 +209,31 @@ size_t FusedCastAdamFp16Fp32(int16_t *var16, const float *gradient32, float *m,
|
|||
global_norm_reciprocal, end);
|
||||
return c1;
|
||||
}
|
||||
|
||||
int DoAdam(float *m, float *v, const float *gradient, float *weight, float beta1, float beta2, float *beta1_power,
|
||||
float *beta2_power, float eps, float learning_rate, bool nesterov, int start, int end) {
|
||||
if ((1.f - beta1_power[0]) <= 0.0f) {
|
||||
return NNACL_PARAM_INVALID;
|
||||
}
|
||||
if ((1.f - beta2_power[0]) < 0.0f) {
|
||||
return NNACL_ERRCODE_SQRT_NEGATIVE;
|
||||
}
|
||||
|
||||
float update_lr = learning_rate * sqrtf(1.f - beta2_power[0]) / (1.f - beta1_power[0]);
|
||||
const float one_minus_beta1 = 1.f - beta1;
|
||||
const float one_minus_beta2 = 1.f - beta2;
|
||||
if (nesterov) { // Nadam
|
||||
for (int i = start; i < end; ++i) {
|
||||
m[i] += (gradient[i] - m[i]) * one_minus_beta1;
|
||||
v[i] += (gradient[i] * gradient[i] - v[i]) * one_minus_beta2;
|
||||
weight[i] -= update_lr * (m[i] * beta1 + one_minus_beta1 * gradient[i]) / (sqrtf(v[i]) + eps);
|
||||
}
|
||||
} else {
|
||||
for (int i = start; i < end; ++i) {
|
||||
m[i] += (gradient[i] - m[i]) * one_minus_beta1;
|
||||
v[i] += (gradient[i] * gradient[i] - v[i]) * one_minus_beta2;
|
||||
weight[i] -= update_lr * m[i] / (sqrtf(v[i]) + eps);
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -40,6 +40,9 @@ size_t FusedCastAdamFp16Fp16(int16_t *var16, const int16_t *gradient16, float *m
|
|||
size_t FusedCastAdamFp16Fp32(int16_t *var16, const float *gradient32, float *m, float *v, float lr, float beta1,
|
||||
float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start,
|
||||
size_t end);
|
||||
|
||||
int DoAdam(float *m, float *v, const float *gradient, float *weight, float beta1, float beta2, float *beta1_power,
|
||||
float *beta2_power, float eps, float learning_rate, bool nesterov, int start, int end);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/fp32_grad/softmax_cross_entropy_with_logits.h"
|
||||
#include <math.h>
|
||||
|
||||
void ForwardPostExecute(const float *labels, const float *logits, float *grads, float *output2,
|
||||
size_t number_of_classes, int batch_size) {
|
||||
float eps = 1e-6;
|
||||
if (grads != NULL) {
|
||||
for (size_t i = 0; i < (size_t)(batch_size); ++i) {
|
||||
float loss = 0.f;
|
||||
for (size_t j = 0; j < number_of_classes; ++j) {
|
||||
float logit = -logf(logits[i * number_of_classes + j] <= 0.0 ? eps : logits[i * number_of_classes + j]);
|
||||
grads[i * number_of_classes + j] = (logits[i * number_of_classes + j] - labels[i * number_of_classes + j]);
|
||||
loss += labels[i * number_of_classes + j] * logit;
|
||||
}
|
||||
output2[i] = loss;
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < (size_t)(batch_size); ++i) {
|
||||
float loss = 0.f;
|
||||
for (size_t j = 0; j < number_of_classes; ++j) {
|
||||
float logit = -logf(logits[i * number_of_classes + j] <= 0.0 ? eps : logits[i * number_of_classes + j]);
|
||||
loss += labels[i * number_of_classes + j] * logit;
|
||||
}
|
||||
output2[i] = loss;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GRAD_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GRAD_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
void ForwardPostExecute(const float *labels, const float *logits, float *grads, float *output2,
|
||||
size_t number_of_classes, int batch_size);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GRAD_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -15,11 +15,11 @@
|
|||
*/
|
||||
|
||||
#include "src/net_runner.h"
|
||||
#include <math.h>
|
||||
#include <getopt.h>
|
||||
#include <stdio.h>
|
||||
#include <malloc.h>
|
||||
#include <cstring>
|
||||
#include <cmath>
|
||||
#include <chrono>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
|
@ -63,7 +63,7 @@ constexpr static int kElem2Print = 10;
|
|||
class Rescaler : public mindspore::TrainCallBack {
|
||||
public:
|
||||
explicit Rescaler(float scale) : scale_(scale) {
|
||||
if (scale_ == 0) {
|
||||
if (std::fabs(scale) <= std::numeric_limits<float>::epsilon()) {
|
||||
scale_ = 1.0;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "schema/model_generated.h"
|
||||
#include "src/runtime/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "nnacl/fp32/adam_fp32.h"
|
||||
#include "plugin/device/cpu/kernel/nnacl/op_base.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH;
|
||||
|
@ -41,43 +42,13 @@ constexpr static int kGradientIdx = 9;
|
|||
|
||||
int AdamCPUKernel::ReSize() { return RET_OK; }
|
||||
|
||||
static int DoAdam(float *m, float *v, const float *gradient, float *weight, float beta1, float beta2, float beta1_power,
|
||||
float beta2_power, float eps, float learning_rate, bool nesterov, int start, int end) {
|
||||
if ((1.f - beta1_power) <= 0.0f) {
|
||||
MS_LOG(ERROR) << "divisor cannot be 0 or below";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if ((1.f - beta2_power) < 0.0f) {
|
||||
MS_LOG(ERROR) << "sqrt cannot be negative";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto update_lr = learning_rate * std::sqrt(1.f - beta2_power) / (1.f - beta1_power);
|
||||
const float one_minus_beta1 = 1.f - beta1;
|
||||
const float one_minus_beta2 = 1.f - beta2;
|
||||
if (nesterov) { // Nadam
|
||||
for (int i = start; i < end; ++i) {
|
||||
m[i] += (gradient[i] - m[i]) * one_minus_beta1;
|
||||
v[i] += (gradient[i] * gradient[i] - v[i]) * one_minus_beta2;
|
||||
weight[i] -= update_lr * (m[i] * beta1 + one_minus_beta1 * gradient[i]) / (std::sqrt(v[i]) + eps);
|
||||
}
|
||||
} else {
|
||||
for (int i = start; i < end; ++i) {
|
||||
m[i] += (gradient[i] - m[i]) * one_minus_beta1;
|
||||
v[i] += (gradient[i] * gradient[i] - v[i]) * one_minus_beta2;
|
||||
weight[i] -= update_lr * m[i] / (std::sqrt(v[i]) + eps);
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int AdamCPUKernel::DoExecute(int task_id) {
|
||||
CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_10D);
|
||||
auto weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIdx)->MutableData());
|
||||
auto m = reinterpret_cast<float *>(in_tensors_.at(kMomentVector1stIdx)->MutableData());
|
||||
auto v = reinterpret_cast<float *>(in_tensors_.at(kMomentVector2stIdx)->MutableData());
|
||||
auto beta1_power = reinterpret_cast<float *>(in_tensors_.at(kBeta1PowerIdx)->MutableData())[0];
|
||||
auto beta2_power = reinterpret_cast<float *>(in_tensors_.at(kBeta2PowerIdx)->MutableData())[0];
|
||||
auto beta1_power = reinterpret_cast<float *>(in_tensors_.at(kBeta1PowerIdx)->MutableData());
|
||||
auto beta2_power = reinterpret_cast<float *>(in_tensors_.at(kBeta2PowerIdx)->MutableData());
|
||||
auto learning_rate = lr_;
|
||||
auto beta1 = reinterpret_cast<float *>(in_tensors_.at(kBeta1Idx)->MutableData())[0];
|
||||
auto beta2 = reinterpret_cast<float *>(in_tensors_.at(kBeta2Idx)->MutableData())[0];
|
||||
|
@ -146,8 +117,8 @@ int AdamCPUKernel::OptimizerStep() {
|
|||
auto weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIdx)->MutableData());
|
||||
auto m = reinterpret_cast<float *>(in_tensors_.at(kMomentVector1stIdx)->MutableData());
|
||||
auto v = reinterpret_cast<float *>(in_tensors_.at(kMomentVector2stIdx)->MutableData());
|
||||
auto beta1_power = reinterpret_cast<float *>(in_tensors_.at(kBeta1PowerIdx)->MutableData())[0];
|
||||
auto beta2_power = reinterpret_cast<float *>(in_tensors_.at(kBeta2PowerIdx)->MutableData())[0];
|
||||
auto beta1_power = reinterpret_cast<float *>(in_tensors_.at(kBeta1PowerIdx)->MutableData());
|
||||
auto beta2_power = reinterpret_cast<float *>(in_tensors_.at(kBeta2PowerIdx)->MutableData());
|
||||
auto learning_rate = lr_;
|
||||
auto beta1 = reinterpret_cast<float *>(in_tensors_.at(kBeta1Idx)->MutableData())[0];
|
||||
auto beta2 = reinterpret_cast<float *>(in_tensors_.at(kBeta2Idx)->MutableData())[0];
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "src/runtime/kernel_registry.h"
|
||||
#include "nnacl/softmax_parameter.h"
|
||||
#include "nnacl/fp32/softmax_fp32.h"
|
||||
#include "nnacl/fp32_grad/softmax_cross_entropy_with_logits.h"
|
||||
#include "src/runtime/kernel/cpu/fp32_grad/softmax_cross_entropy_with_logits.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
|
@ -29,34 +30,6 @@ using mindspore::schema::PrimitiveType_SoftmaxCrossEntropyWithLogits;
|
|||
namespace mindspore::kernel {
|
||||
int SoftmaxCrossEntropyWithLogitsCPUKernel::Prepare() { return ReSize(); }
|
||||
|
||||
void SoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const float *labels, const float *logits, float *grads,
|
||||
float *output2) const {
|
||||
float eps = 1e-6;
|
||||
if (grads != nullptr) {
|
||||
for (size_t i = 0; i < static_cast<size_t>(param_->batch_size_); ++i) {
|
||||
float loss = 0.f;
|
||||
for (size_t j = 0; j < param_->number_of_classes_; ++j) {
|
||||
float logit =
|
||||
-logf(logits[i * param_->number_of_classes_ + j] <= 0.0 ? eps : logits[i * param_->number_of_classes_ + j]);
|
||||
grads[i * param_->number_of_classes_ + j] =
|
||||
(logits[i * param_->number_of_classes_ + j] - labels[i * param_->number_of_classes_ + j]);
|
||||
loss += labels[i * param_->number_of_classes_ + j] * logit;
|
||||
}
|
||||
output2[i] = loss;
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < static_cast<size_t>(param_->batch_size_); ++i) {
|
||||
float loss = 0.f;
|
||||
for (size_t j = 0; j < param_->number_of_classes_; ++j) {
|
||||
float logit =
|
||||
-logf(logits[i * param_->number_of_classes_ + j] <= 0.0 ? eps : logits[i * param_->number_of_classes_ + j]);
|
||||
loss += labels[i * param_->number_of_classes_ + j] * logit;
|
||||
}
|
||||
output2[i] = loss;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int SoftmaxCrossEntropyWithLogitsCPUKernel::DoExecute(int task_id) {
|
||||
auto ins = reinterpret_cast<float *>(in_tensors_.at(0)->data());
|
||||
CHECK_NULL_RETURN(ins);
|
||||
|
@ -75,7 +48,7 @@ int SoftmaxCrossEntropyWithLogitsCPUKernel::DoExecute(int task_id) {
|
|||
std::fill(losses_, losses_ + data_size, 0);
|
||||
std::fill(sum_data_, sum_data_ + sm_params_.input_shape_[0], 0);
|
||||
Softmax(ins, losses_, sum_data_, &sm_params_);
|
||||
ForwardPostExecute(labels, losses_, grads, out);
|
||||
ForwardPostExecute(labels, losses_, grads, out, param_->number_of_classes_, param_->batch_size_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -35,8 +35,6 @@ class SoftmaxCrossEntropyWithLogitsCPUKernel : public LossKernel {
|
|||
}
|
||||
~SoftmaxCrossEntropyWithLogitsCPUKernel() override = default;
|
||||
|
||||
void ForwardPostExecute(const float *labels, const float *logits, float *output1, float *output2) const;
|
||||
|
||||
int Prepare() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
|
@ -46,7 +44,6 @@ class SoftmaxCrossEntropyWithLogitsCPUKernel : public LossKernel {
|
|||
SoftmaxCrossEntropyParameter *param_;
|
||||
SoftmaxParameter sm_params_;
|
||||
};
|
||||
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_
|
||||
|
|
|
@ -100,6 +100,12 @@ set(CODER_OPCODERS_SRC
|
|||
${MICRO_DIR}/coder/opcoders/nnacl/fp32/exp_fp32_coder.cc
|
||||
${MICRO_DIR}/coder/opcoders/nnacl/fp32/deconv2d_fp32_coder.cc
|
||||
${MICRO_DIR}/coder/opcoders/nnacl/fp32/prelu_fp32_coder.cc
|
||||
#### nnacl fp32_grad coder
|
||||
${MICRO_DIR}/coder/opcoders/nnacl/fp32_grad/activation_grad_coder.cc
|
||||
${MICRO_DIR}/coder/opcoders/nnacl/fp32_grad/adam_coder.cc
|
||||
${MICRO_DIR}/coder/opcoders/nnacl/fp32_grad/assign_coder.cc
|
||||
${MICRO_DIR}/coder/opcoders/nnacl/fp32_grad/biasadd_grad_coder.cc
|
||||
${MICRO_DIR}/coder/opcoders/nnacl/fp32_grad/softmax_cross_entropy_with_logits_coder.cc
|
||||
#### nnacl int8 coder
|
||||
${MICRO_DIR}/coder/opcoders/nnacl/int8/activation_int8_coder.cc
|
||||
${MICRO_DIR}/coder/opcoders/nnacl/int8/affine_int8_coder.cc
|
||||
|
|
|
@ -120,18 +120,18 @@ int Coder::MicroSourceCodeGeneration(const schema::MetaGraphT &graph, const std:
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int Coder::Init(const std::string code_mode, const std::string target, bool support_parallel, bool debug_mode) const {
|
||||
int Coder::Init(const std::string &code_mode, const std::string &target, bool support_parallel, bool debug_mode) const {
|
||||
static const std::map<std::string, Target> kTargetMap = {
|
||||
{"x86", kX86}, {"Cortex-M", kCortex_M}, {"ARM32", kARM32}, {"ARM64", kARM64}, {"All", kAllTargets}};
|
||||
static const std::map<std::string, CodeMode> kCodeModeMap = {{"Inference", Inference}, {"Train", Train}};
|
||||
Configurator *config = Configurator::GetInstance();
|
||||
|
||||
auto target_item = kTargetMap.find(target);
|
||||
MS_CHECK_TRUE_RET_BOOL(target_item != kTargetMap.end(), "unsupported target: " + target);
|
||||
MS_CHECK_TRUE_MSG(target_item != kTargetMap.end(), RET_ERROR, "unsupported target: " + target);
|
||||
config->set_target(target_item->second);
|
||||
|
||||
auto code_item = kCodeModeMap.find(code_mode);
|
||||
MS_CHECK_TRUE_RET_BOOL(code_item != kCodeModeMap.end(), "unsupported code mode: " + code_mode);
|
||||
MS_CHECK_TRUE_MSG(code_item != kCodeModeMap.end(), RET_ERROR, "unsupported code mode: " + code_mode);
|
||||
config->set_code_mode(code_item->second);
|
||||
|
||||
if (support_parallel && config->target() == kCortex_M) {
|
||||
|
|
|
@ -35,7 +35,7 @@ class Coder final {
|
|||
bool support_parallel, bool debug_mode);
|
||||
|
||||
private:
|
||||
int Init(const std::string code_mode, const std::string target, bool support_parallel, bool debug_mode_) const;
|
||||
int Init(const std::string &code_mode, const std::string &target, bool support_parallel, bool debug_mode_) const;
|
||||
int Run(const void *model_buff, size_t size);
|
||||
bool InitPath(const std::string &output_path);
|
||||
std::shared_ptr<CoderSession> session_{nullptr};
|
||||
|
|
|
@ -229,7 +229,7 @@ void CodeInputImplement(std::ofstream &ofs, const std::unique_ptr<CoderContext>
|
|||
// input tensors
|
||||
std::vector<Tensor *> inputs = ctx->graph_inputs();
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
ofs << "static const unsigned char *" << ctx->input_name() + std::to_string(i) << " = 0;\n";
|
||||
ofs << "const unsigned char *" << ctx->input_name() + std::to_string(i) << " = 0;\n";
|
||||
}
|
||||
size_t size = inputs.size();
|
||||
ofs << "int "
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -41,6 +41,7 @@ const char debug_utils_h[] = R"RAW(
|
|||
#include <sys/time.h>
|
||||
#include <time.h>
|
||||
#include <stdint.h>
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
#define MICRO_INFO(content, args...) \
|
||||
{ printf("[INFO] %s|%d: " #content "\r\n", __func__, __LINE__, ##args); }
|
||||
|
@ -62,24 +63,9 @@ enum DataType {
|
|||
DataType_MAX = DataType_DT_UNDEFINED
|
||||
};
|
||||
|
||||
enum Format {
|
||||
Format_NCHW = 0,
|
||||
Format_NHWC = 1,
|
||||
Format_HWKC = 2,
|
||||
Format_HWCK = 3,
|
||||
Format_KCHW = 4,
|
||||
Format_CKHW = 5,
|
||||
Format_KHWC = 6,
|
||||
Format_CHWK = 7,
|
||||
Format_NC4HW4 = 100,
|
||||
Format_NUM_OF_FORMAT = 101,
|
||||
Format_MIN = Format_NCHW,
|
||||
Format_MAX = Format_NUM_OF_FORMAT
|
||||
};
|
||||
|
||||
typedef struct {
|
||||
enum DataType type;
|
||||
enum Format format;
|
||||
enum FormatC format;
|
||||
int ndim;
|
||||
int *dim;
|
||||
void *data;
|
||||
|
@ -140,7 +126,7 @@ static const char *const TypeNames[] = {"DT_FLOAT", "DT_FLOAT16", "DT_INT8", "
|
|||
"", "", "DT_UINT32", "DT_INT64", "DT_UINT16", "",
|
||||
"", "", "", "", "DT_UNDEFINED", ""};
|
||||
|
||||
const char *EnumNameFormat(enum Format e) {
|
||||
const char *EnumNameFormat(enum FormatC e) {
|
||||
switch (e) {
|
||||
case Format_NCHW:
|
||||
return "NCHW";
|
||||
|
@ -160,8 +146,6 @@ const char *EnumNameFormat(enum Format e) {
|
|||
return "CHWK";
|
||||
case Format_NC4HW4:
|
||||
return "NC4HW4";
|
||||
case Format_NUM_OF_FORMAT:
|
||||
return "NUM_OF_FORMAT";
|
||||
default:
|
||||
return "";
|
||||
}
|
||||
|
|
|
@ -254,7 +254,10 @@ int Generator::CodeWeightFile() {
|
|||
cofs << g_hwLicense;
|
||||
cofs << "#include \"" << net_weight_hfile_ << "\"\n\n";
|
||||
cofs << "int " << gThreadNum << " = 1; \n";
|
||||
|
||||
std::vector<Tensor *> inputs = ctx_->graph_inputs();
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
cofs << "extern const unsigned char *" << ctx_->input_name() + std::to_string(i) << ";\n";
|
||||
}
|
||||
if (config_->target() != kCortex_M) {
|
||||
cofs << "unsigned char * " << ctx_->buffer_name() << " = 0; \n";
|
||||
cofs << "unsigned char * " << ctx_->weight_name() << " = 0; \n";
|
||||
|
|
|
@ -14,9 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "coder/opcoders/nnacl/fp32/activation_fp32_coder.h"
|
||||
#include <string>
|
||||
#include "nnacl/fp32/activation_fp32.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
|
||||
#include "coder/opcoders/file_collector.h"
|
||||
#include "coder/opcoders/parallel.h"
|
||||
|
|
|
@ -204,10 +204,7 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) {
|
|||
// do const value packing to init
|
||||
if (!params_->a_const_) {
|
||||
code.CodeFunction("InitMatrixA", input_tensor_, a_pack_ptr_, "&mat_mul_parameter", vec_matmul_);
|
||||
if (!params_->b_const_) {
|
||||
init_code.CodeMallocExpression(b_pack_ptr_, b_pack_ptr_size_);
|
||||
init_code.CodeFunction("memset", b_pack_ptr_, 0, b_pack_ptr_size_);
|
||||
} else {
|
||||
if (params_->b_const_) {
|
||||
init_code.CodeBufferOffsetExpression(b_pack_ptr_, context->weight_name(), context->weight_offset_name(),
|
||||
context->weight_size_name(), b_pack_ptr_size_);
|
||||
w_buf_size += b_pack_ptr_size_;
|
||||
|
@ -223,10 +220,7 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) {
|
|||
init_code.CodeFunction("InitMatrixB", b_src_str, b_pack_ptr_, "&mat_mul_parameter", vec_matmul_);
|
||||
}
|
||||
if (!params_->b_const_) {
|
||||
if (!params_->a_const_) {
|
||||
init_code.CodeMallocExpression(a_pack_str, a_pack_ptr_size_);
|
||||
init_code.CodeFunction("memset", a_pack_ptr_, 0, a_pack_ptr_size_);
|
||||
} else {
|
||||
if (params_->a_const_) {
|
||||
init_code.CodeBufferOffsetExpression(a_pack_ptr_, context->weight_name(), context->weight_offset_name(),
|
||||
context->weight_size_name(), a_pack_ptr_size_);
|
||||
w_buf_size += a_pack_ptr_size_;
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "coder/opcoders/nnacl/fp32_grad/activation_grad_coder.h"
|
||||
#include "nnacl/fp32_grad/activation_grad.h"
|
||||
#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
|
||||
#include "coder/opcoders/file_collector.h"
|
||||
|
||||
using mindspore::schema::PrimitiveType_ActivationGrad;
|
||||
|
||||
namespace mindspore::lite::micro::nnacl {
|
||||
int ActivationGradCoder::DoCode(CoderContext *const context) {
|
||||
MS_CHECK_TRUE(input_tensors_.size() == DIMENSION_2D, "inputs size is not equal to two");
|
||||
Tensor *input0 = input_tensors_.at(0);
|
||||
Tensor *input1 = input_tensors_.at(1);
|
||||
// attribute
|
||||
auto *activation_parameter = reinterpret_cast<ActivationParameter *>(parameter_);
|
||||
int count = input_tensor_->ElementsNum();
|
||||
Collect(context,
|
||||
{
|
||||
"nnacl/fp32_grad/activation_grad.h",
|
||||
},
|
||||
{
|
||||
"activation_grad.c",
|
||||
});
|
||||
NNaclFp32Serializer code;
|
||||
|
||||
switch (activation_parameter->type_) {
|
||||
case schema::ActivationType_RELU:
|
||||
code.CodeFunction("ReluGrad", input0, input1, count, output_tensor_);
|
||||
break;
|
||||
case schema::ActivationType_ELU:
|
||||
code.CodeFunction("EluGrad", input0, input1, count, output_tensor_, activation_parameter->alpha_);
|
||||
break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "Activation type error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_LOG(DEBUG) << "ActivationGradCode has been called";
|
||||
context->AppendCode(code.str());
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_ActivationGrad,
|
||||
CPUOpCoderCreator<ActivationGradCoder>)
|
||||
} // namespace mindspore::lite::micro::nnacl
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_GRAD_ACTIVATION_GRAD_CODER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_GRAD_ACTIVATION_GRAD_CODER_H_
|
||||
|
||||
#include <vector>
|
||||
#include "coder/opcoders/op_coder.h"
|
||||
|
||||
namespace mindspore::lite::micro::nnacl {
|
||||
class ActivationGradCoder final : public OperatorCoder {
|
||||
public:
|
||||
ActivationGradCoder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
||||
const LiteGraph::Node *node, size_t node_index, Target target)
|
||||
: OperatorCoder(in_tensors, out_tensors, node, node_index, target) {}
|
||||
|
||||
~ActivationGradCoder() override = default;
|
||||
|
||||
int Prepare(CoderContext *const context) override { return RET_OK; }
|
||||
|
||||
int DoCode(CoderContext *const context) override;
|
||||
};
|
||||
} // namespace mindspore::lite::micro::nnacl
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_GRAD_ACTIVATION_GRAD_CODER_H_
|
|
@ -0,0 +1,69 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "coder/opcoders/nnacl/fp32_grad/adam_coder.h"
|
||||
#include "nnacl/fp32_grad/optimizer.h"
|
||||
#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
|
||||
#include "coder/opcoders/file_collector.h"
|
||||
|
||||
using mindspore::schema::PrimitiveType_Adam;
|
||||
|
||||
namespace mindspore::lite::micro::nnacl {
|
||||
namespace {
|
||||
constexpr int kWeightIdx = 0;
|
||||
constexpr int kMomentVector1stIdx = 1;
|
||||
constexpr int kMomentVector2stIdx = 2;
|
||||
constexpr int kBeta1PowerIdx = 3;
|
||||
constexpr int kBeta2PowerIdx = 4;
|
||||
constexpr int kLearningRateIdx = 5;
|
||||
constexpr int kBeta1Idx = 6;
|
||||
constexpr int kBeta2Idx = 7;
|
||||
constexpr int kEpsilonIdx = 8;
|
||||
constexpr int kGradientIdx = 9;
|
||||
} // namespace
|
||||
int AdamCoder::DoCode(CoderContext *const context) {
|
||||
MS_CHECK_TRUE(input_tensors_.size() >= DIMENSION_10D, "inputs size is less than 10");
|
||||
auto weight = input_tensors_.at(kWeightIdx);
|
||||
auto m = input_tensors_.at(kMomentVector1stIdx);
|
||||
auto v = input_tensors_.at(kMomentVector2stIdx);
|
||||
auto beta1_power = input_tensors_.at(kBeta1PowerIdx);
|
||||
auto beta2_power = input_tensors_.at(kBeta2PowerIdx);
|
||||
auto learning_rate = reinterpret_cast<float *>(
|
||||
input_tensors_.at(kLearningRateIdx)->MutableData())[0]; // use model origin lr, unsupported to config
|
||||
auto beta1 = reinterpret_cast<float *>(input_tensors_.at(kBeta1Idx)->MutableData())[0];
|
||||
auto beta2 = reinterpret_cast<float *>(input_tensors_.at(kBeta2Idx)->MutableData())[0];
|
||||
auto eps = reinterpret_cast<float *>(input_tensors_.at(kEpsilonIdx)->MutableData())[0];
|
||||
auto gradient = input_tensors_.at(kGradientIdx);
|
||||
int length = input_tensors_.at(kWeightIdx)->ElementsNum();
|
||||
|
||||
// attribute
|
||||
auto *adam_param = reinterpret_cast<AdamParameter *>(parameter_);
|
||||
Collect(context,
|
||||
{
|
||||
"nnacl/fp32/adam_fp32.h",
|
||||
},
|
||||
{
|
||||
"adam_fp32.c",
|
||||
});
|
||||
NNaclFp32Serializer code;
|
||||
code.CodeFunction("DoAdam", m, v, gradient, weight, beta1, beta2, beta1_power, beta2_power, eps, learning_rate,
|
||||
adam_param->use_nesterov_, 0, length);
|
||||
context->AppendCode(code.str());
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Adam, CPUOpCoderCreator<AdamCoder>)
|
||||
} // namespace mindspore::lite::micro::nnacl
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_GRAD_ADAM_CODER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_GRAD_ADAM_CODER_H_
|
||||
|
||||
#include <vector>
|
||||
#include "coder/opcoders/op_coder.h"
|
||||
|
||||
namespace mindspore::lite::micro::nnacl {
|
||||
class AdamCoder final : public OperatorCoder {
|
||||
public:
|
||||
AdamCoder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
||||
const LiteGraph::Node *node, size_t node_index, Target target)
|
||||
: OperatorCoder(in_tensors, out_tensors, node, node_index, target) {}
|
||||
|
||||
~AdamCoder() override = default;
|
||||
|
||||
int Prepare(CoderContext *const context) override { return RET_OK; }
|
||||
|
||||
int DoCode(CoderContext *const context) override;
|
||||
};
|
||||
} // namespace mindspore::lite::micro::nnacl
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_GRAD_ADAM_CODER_H_
|
|
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "coder/opcoders/nnacl/fp32_grad/assign_coder.h"
|
||||
#include <string>
|
||||
#include "schema/inner/ops_generated.h"
|
||||
#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
|
||||
|
||||
namespace mindspore::lite::micro::nnacl {
|
||||
using mindspore::schema::PrimitiveType_Assign;
|
||||
|
||||
int AssignCoder::Prepare(CoderContext *const context) { return RET_OK; }
|
||||
|
||||
int AssignCoder::DoCode(CoderContext *const context) {
|
||||
MS_CHECK_TRUE(input_tensors_.size() == DIMENSION_2D, "inputs size is not equal to two");
|
||||
Tensor *input0 = input_tensors_.at(0);
|
||||
Tensor *input1 = input_tensors_.at(1);
|
||||
if (input0->Size() != input1->Size()) {
|
||||
MS_LOG(ERROR) << "input0 size: " << input0->Size() << ", input1 size: " << input1->Size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
NNaclFp32Serializer code;
|
||||
// Get Tensor Pointer
|
||||
std::string input0_str = allocator_->GetRuntimeAddr(input0);
|
||||
std::string input1_str = allocator_->GetRuntimeAddr(input1);
|
||||
size_t data_size = input0->Size();
|
||||
// assign, just assign input1'data to input0
|
||||
code.CodeFunction("memcpy", input0_str, input1_str, data_size);
|
||||
context->AppendCode(code.str());
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Assign, CPUOpCoderCreator<AssignCoder>)
|
||||
} // namespace mindspore::lite::micro::nnacl
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_GRAD_ASSIGN_CODER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_GRAD_ASSIGN_CODER_H_
|
||||
|
||||
#include <vector>
|
||||
#include "coder/opcoders/op_coder.h"
|
||||
|
||||
namespace mindspore::lite::micro::nnacl {
|
||||
class AssignCoder final : public OperatorCoder {
|
||||
public:
|
||||
AssignCoder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
||||
const LiteGraph::Node *node, size_t node_index, Target target)
|
||||
: OperatorCoder(in_tensors, out_tensors, node, node_index, target) {}
|
||||
~AssignCoder() override = default;
|
||||
|
||||
int Prepare(CoderContext *const context) override;
|
||||
|
||||
int DoCode(CoderContext *const context) override;
|
||||
};
|
||||
} // namespace mindspore::lite::micro::nnacl
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_GRAD_ASSIGN_CODER_H_
|
|
@ -0,0 +1,68 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "coder/opcoders/nnacl/fp32_grad/biasadd_grad_coder.h"
|
||||
#include <string>
|
||||
#include "schema/inner/ops_generated.h"
|
||||
#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
|
||||
|
||||
namespace mindspore::lite::micro::nnacl {
|
||||
using mindspore::schema::PrimitiveType_BiasAddGrad;
|
||||
|
||||
int BiasAddGradCoder::Prepare(CoderContext *const context) {
|
||||
auto dims = input_tensor_->shape();
|
||||
auto *bias_param = reinterpret_cast<ArithmeticParameter *>(parameter_);
|
||||
bias_param->ndim_ = dims.size();
|
||||
for (unsigned int i = 0; i < bias_param->ndim_; i++) {
|
||||
bias_param->in_shape0_[i] = dims[i];
|
||||
bias_param->out_shape_[i] = 1; // 1 dimension for N,H,W,
|
||||
}
|
||||
bias_param->out_shape_[bias_param->ndim_ - 1] = dims[bias_param->ndim_ - 1];
|
||||
for (auto i = bias_param->ndim_; i < DIMENSION_4D; i++) {
|
||||
bias_param->in_shape0_[i] = 0;
|
||||
bias_param->out_shape_[i] = 0;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int BiasAddGradCoder::DoCode(CoderContext *const context) {
|
||||
auto *bias_param = reinterpret_cast<ArithmeticParameter *>(parameter_);
|
||||
size_t nhw_size = 1;
|
||||
size_t channels = bias_param->in_shape0_[bias_param->ndim_ - 1]; // C in NHWC
|
||||
for (size_t i = 0; i < bias_param->ndim_ - 1; i++) {
|
||||
nhw_size *= static_cast<size_t>(bias_param->in_shape0_[i]);
|
||||
}
|
||||
|
||||
size_t total_size = channels * nhw_size;
|
||||
|
||||
NNaclFp32Serializer code;
|
||||
// Get Tensor Pointer
|
||||
std::string input_str = allocator_->GetRuntimeAddr(input_tensor_);
|
||||
std::string output_str = allocator_->GetRuntimeAddr(output_tensor_);
|
||||
|
||||
code << "\t\tfor (size_t c = 0; c < " << channels << "; ++c) {\n";
|
||||
code << "\t\t\t(" << output_str << ")[c] = 0;\n";
|
||||
code << "\t\t\tfor (size_t offset = 0; offset < " << total_size << "; offset += " << channels << ") {\n";
|
||||
code << "\t\t\t\t(" << output_str << ")[c] += (" << input_str << ")[offset + c];\n";
|
||||
code << "\t\t\t}\n";
|
||||
code << "\t\t}\n";
|
||||
|
||||
context->AppendCode(code.str());
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_BiasAddGrad, CPUOpCoderCreator<BiasAddGradCoder>)
|
||||
} // namespace mindspore::lite::micro::nnacl
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_GRAD_BIASADD_GRAD_CODER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_GRAD_BIASADD_GRAD_CODER_H_
|
||||
|
||||
#include <vector>
|
||||
#include "coder/opcoders/op_coder.h"
|
||||
|
||||
namespace mindspore::lite::micro::nnacl {
|
||||
class BiasAddGradCoder final : public OperatorCoder {
|
||||
public:
|
||||
BiasAddGradCoder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
||||
const LiteGraph::Node *node, size_t node_index, Target target)
|
||||
: OperatorCoder(in_tensors, out_tensors, node, node_index, target) {}
|
||||
~BiasAddGradCoder() override = default;
|
||||
|
||||
int Prepare(CoderContext *const context) override;
|
||||
|
||||
int DoCode(CoderContext *const context) override;
|
||||
};
|
||||
} // namespace mindspore::lite::micro::nnacl
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_GRAD_BIASADD_GRAD_CODER_H_
|
|
@ -0,0 +1,80 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "coder/opcoders/nnacl/fp32_grad/softmax_cross_entropy_with_logits_coder.h"
|
||||
#include <string>
|
||||
#include "nnacl/fp32_grad/softmax_crossentropy_parameter.h"
|
||||
#include "coder/opcoders/file_collector.h"
|
||||
#include "schema/inner/ops_generated.h"
|
||||
#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
|
||||
|
||||
namespace mindspore::lite::micro::nnacl {
|
||||
using mindspore::schema::PrimitiveType_SoftmaxCrossEntropyWithLogits;
|
||||
|
||||
int SoftmaxCrossEntropyWithLogitsCoder::Prepare(CoderContext *const context) {
|
||||
MS_CHECK_TRUE(input_tensor_ != nullptr, "input_tensor is nullptr.");
|
||||
size_t data_size = input_tensor_->ElementsNum();
|
||||
auto dims = input_tensor_->shape();
|
||||
auto *softmax_cross_entropy_param = reinterpret_cast<SoftmaxCrossEntropyParameter *>(parameter_);
|
||||
softmax_cross_entropy_param->n_dim_ = DIMENSION_2D;
|
||||
CHECK_LESS_RETURN(dims.size(), DIMENSION_2D);
|
||||
softmax_cross_entropy_param->number_of_classes_ = dims.at(1);
|
||||
softmax_cross_entropy_param->batch_size_ = dims.at(0);
|
||||
|
||||
losses_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, data_size * sizeof(float), kWorkspace));
|
||||
sum_data_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, dims[0] * sizeof(float), kWorkspace));
|
||||
softmax_params_.n_dim_ = DIMENSION_2D;
|
||||
softmax_params_.element_size_ = data_size;
|
||||
softmax_params_.axis_ = 1;
|
||||
for (size_t i = 0; i < dims.size(); i++) {
|
||||
softmax_params_.input_shape_[i] = dims.at(i);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int SoftmaxCrossEntropyWithLogitsCoder::DoCode(CoderContext *const context) {
|
||||
MS_CHECK_TRUE(input_tensors_.size() == DIMENSION_2D, "inputs size is not equal to two");
|
||||
Collect(context,
|
||||
{
|
||||
"nnacl/fp32/softmax_fp32.h",
|
||||
"nnacl/fp32_grad/softmax_cross_entropy_with_logits.h",
|
||||
},
|
||||
{
|
||||
"softmax_fp32.c",
|
||||
"softmax_cross_entropy_with_logits.c",
|
||||
});
|
||||
NNaclFp32Serializer code, init_code;
|
||||
code.CodeStruct("softmax_params", softmax_params_);
|
||||
// Get Tensor Pointer
|
||||
std::string in_str = allocator_->GetRuntimeAddr(input_tensor_);
|
||||
std::string labels_str = allocator_->GetRuntimeAddr(input_tensors_.at(1));
|
||||
std::string out_str = allocator_->GetRuntimeAddr(output_tensor_);
|
||||
std::string grad_str = "NULL";
|
||||
if (output_tensors_.size() > 1) {
|
||||
grad_str = allocator_->GetRuntimeAddr(output_tensors_.at(1));
|
||||
}
|
||||
auto *softmax_cross_entropy_param = reinterpret_cast<SoftmaxCrossEntropyParameter *>(parameter_);
|
||||
code.CodeFunction("Softmax", in_str, losses_, sum_data_, "&softmax_params");
|
||||
code.CodeFunction("ForwardPostExecute", labels_str, losses_, grad_str, out_str,
|
||||
softmax_cross_entropy_param->number_of_classes_, softmax_cross_entropy_param->batch_size_);
|
||||
|
||||
context->AppendCode(code.str());
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_SoftmaxCrossEntropyWithLogits,
|
||||
CPUOpCoderCreator<SoftmaxCrossEntropyWithLogitsCoder>)
|
||||
} // namespace mindspore::lite::micro::nnacl
|
|
@ -0,0 +1,43 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_GRAD_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CODER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_GRAD_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CODER_H_
|
||||
|
||||
#include <vector>
|
||||
#include "nnacl/softmax_parameter.h"
|
||||
#include "coder/opcoders/op_coder.h"
|
||||
|
||||
namespace mindspore::lite::micro::nnacl {
|
||||
class SoftmaxCrossEntropyWithLogitsCoder final : public OperatorCoder {
|
||||
public:
|
||||
SoftmaxCrossEntropyWithLogitsCoder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
|
||||
const LiteGraph::Node *node, size_t node_index, Target target)
|
||||
: OperatorCoder(in_tensors, out_tensors, node, node_index, target) {}
|
||||
~SoftmaxCrossEntropyWithLogitsCoder() override = default;
|
||||
|
||||
int Prepare(CoderContext *const context) override;
|
||||
|
||||
int DoCode(CoderContext *const context) override;
|
||||
|
||||
private:
|
||||
SoftmaxParameter softmax_params_;
|
||||
float *losses_{nullptr};
|
||||
float *sum_data_{nullptr};
|
||||
};
|
||||
} // namespace mindspore::lite::micro::nnacl
|
||||
|
||||
#endif // MICRO_CODER_OPCODERS_NNACL_FP32_GRAD_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CODER_H_
|
|
@ -42,7 +42,8 @@
|
|||
namespace mindspore::lite::micro {
|
||||
CoderSession::CoderSession() { allocator_ = MemoryAllocator::GetInstance(); }
|
||||
|
||||
void CoderSession::EndCode() {
|
||||
int CoderSession::EndCode() {
|
||||
int ret = RET_OK;
|
||||
context_->set_tensor_map(allocator_->tensors_map());
|
||||
context_->set_saved_weights(allocator_->saved_weights());
|
||||
size_t de_quant_max_workspace_size = nnacl::Dequant::GetInstance()->de_quant_max_workspace();
|
||||
|
@ -59,8 +60,10 @@ void CoderSession::EndCode() {
|
|||
context_->set_code_blocks(blocks);
|
||||
}
|
||||
if (config->code_mode() == Train) {
|
||||
Train::TransformGraphForTrain(context_.get(), op_coders_, schema_version_);
|
||||
ret = Train::TransformGraphForTrain(context_.get(), op_coders_, schema_version_);
|
||||
MS_CHECK_RET_CODE(ret, "transform graph for train failed.");
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
int CoderSession::Run() {
|
||||
|
@ -85,7 +88,8 @@ int CoderSession::Run() {
|
|||
MS_CHECK_RET_CODE(ret, "do coder " << op_coder->name() << " failed");
|
||||
}
|
||||
|
||||
this->EndCode();
|
||||
ret = this->EndCode();
|
||||
MS_CHECK_RET_CODE(ret, "End code failed.");
|
||||
MS_LOG(INFO) << "run opcoders success";
|
||||
return RET_OK;
|
||||
}
|
||||
|
@ -220,6 +224,8 @@ int CoderSession::CreateOpCoders() {
|
|||
CodeMode code_mode = config->code_mode();
|
||||
bool support_parallel = config->support_parallel();
|
||||
uint32_t nodes_size = model->graph_.all_nodes_.size();
|
||||
std::vector<lite::Tensor *> all_tensors = coder_graph_->all_tensors();
|
||||
MS_CHECK_TRUE_MSG(!all_tensors.empty(), RET_ERROR, "coder_graph has no any tensors");
|
||||
OpCoderBuilder builder;
|
||||
for (uint32_t i = 0; i < nodes_size; ++i) {
|
||||
const auto *node = model->graph_.all_nodes_.at(i);
|
||||
|
@ -227,19 +233,9 @@ int CoderSession::CreateOpCoders() {
|
|||
MS_LOG(ERROR) << "node is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::vector<lite::Tensor *> all_tensors = coder_graph_->all_tensors();
|
||||
if (all_tensors.empty()) {
|
||||
MS_LOG(ERROR) << "coder_graph has no any tensors";
|
||||
return RET_ERROR;
|
||||
}
|
||||
// set op_coder's inputs && outputs info
|
||||
std::vector<uint32_t> input_indices;
|
||||
std::vector<uint32_t> node_input_indices = node->input_indices_;
|
||||
input_indices.insert(input_indices.end(), node_input_indices.begin(), node_input_indices.end());
|
||||
std::vector<uint32_t> output_indices;
|
||||
std::vector<uint32_t> node_output_indices = node->output_indices_;
|
||||
output_indices.insert(output_indices.end(), node_output_indices.begin(), node_output_indices.end());
|
||||
|
||||
std::vector<uint32_t> input_indices(node->input_indices_);
|
||||
std::vector<uint32_t> output_indices(node->output_indices_);
|
||||
std::vector<lite::Tensor *> inputs;
|
||||
std::vector<lite::Tensor *> outputs;
|
||||
for (auto in_index : input_indices) {
|
||||
|
@ -259,11 +255,11 @@ int CoderSession::CreateOpCoders() {
|
|||
outputs.push_back(all_tensors.at(out_index));
|
||||
}
|
||||
if (inputs.empty()) {
|
||||
MS_LOG(ERROR) << "node: " << node->name_ << "has no inputs tensor";
|
||||
MS_LOG(ERROR) << "node: " << node->name_ << "has no inputs tensor";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (outputs.empty()) {
|
||||
MS_LOG(ERROR) << "node: " << node->name_ << "has no outputs tensor";
|
||||
MS_LOG(ERROR) << "node: " << node->name_ << "has no outputs tensor";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
|
|
|
@ -50,7 +50,7 @@ class CoderSession {
|
|||
int CreateOpCoders();
|
||||
int InitCodeGraph();
|
||||
int CompileGraph();
|
||||
void EndCode();
|
||||
int EndCode();
|
||||
|
||||
std::unique_ptr<CoderGraph> coder_graph_{nullptr};
|
||||
std::unique_ptr<CoderContext> context_{nullptr};
|
||||
|
|
|
@ -60,12 +60,13 @@ int Train::TransformGraphForTrain(CoderContext *context, const std::vector<std::
|
|||
MS_LOG(INFO) << "input context invalid";
|
||||
return RET_ERROR;
|
||||
}
|
||||
const std::array<int, 6> loss_types = {schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits,
|
||||
const std::array<int, 7> loss_types = {schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits,
|
||||
schema::PrimitiveType_BinaryCrossEntropy,
|
||||
schema::PrimitiveType_SmoothL1Loss,
|
||||
schema::PrimitiveType_SmoothL1LossGrad,
|
||||
schema::PrimitiveType_SigmoidCrossEntropyWithLogits,
|
||||
schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad};
|
||||
schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad,
|
||||
schema::PrimitiveType_SoftmaxCrossEntropyWithLogits};
|
||||
OperatorCoder *loss_op = nullptr;
|
||||
for (const auto &opcoder : op_coders) {
|
||||
const LiteGraph::Node *node = opcoder->node();
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -112,6 +112,11 @@ std::vector<std::string> AddDumpDataInfo(const std::vector<std::string> &blocks,
|
|||
auto &opcoder = opcoders.at(i);
|
||||
std::string code = blocks.at(i);
|
||||
std::string name = opcoder->name();
|
||||
auto pos = name.find_first_of('/');
|
||||
while (pos != std::string::npos) {
|
||||
name.replace(pos, 1, ".");
|
||||
pos = name.find_first_of('/');
|
||||
}
|
||||
code += " {\n";
|
||||
code += " FILE *output_file = fopen(\"./" + name + ".ir\", \"w\");\n";
|
||||
code += " fprintf(output_file, \"Node:" + name + "\\n\");\n";
|
||||
|
|
Loading…
Reference in New Issue