Fixed codex/Coverity warnings

Emir Haleva 2021-07-22 17:37:47 +03:00
parent 138f381829
commit 46dec60424
33 changed files with 252 additions and 207 deletions

View File

@@ -21,9 +21,9 @@ from train_utils import save_inout, train_wrap
 import mindspore.common.dtype as mstype
 from mindspore import context, Tensor, nn
 from mindspore.train.serialization import export
-from src.network.densenet import DenseNet121
-sys.path.append(os.environ['CLOUD_MODEL_ZOO'] + 'official/cv/densenet121/')
 #pylint: disable=wrong-import-position
+sys.path.append(os.environ['CLOUD_MODEL_ZOO'] + 'official/cv/densenet121/')
+from src.network.densenet import DenseNet121
 context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False)

View File

@@ -15,6 +15,7 @@
 """emoji_model."""
 import mindspore as MS

+
 class GlobalAvgPooling(MS.nn.Cell):
     """
     Global avg pooling definition.
@@ -33,6 +34,7 @@ class GlobalAvgPooling(MS.nn.Cell):
         x = self.mean(x, (2, 3))
         return x

+
 class EmojiModel(MS.nn.Cell):
     """emoji model"""
     def __init__(self, wayc, use_bb, use_head):

View File

@@ -39,22 +39,8 @@ def save_t(t, file):
     x.tofile(file)

-def save_inout(name, x, l, net, net_train, sparse=False, epoch=1):
-    """save_inout"""
-    x_name = name + "_input1.bin"
-    if sparse:
-        x_name = name + "_input2.bin"
-    save_t(Tensor(x.asnumpy().transpose(0, 2, 3, 1)), x_name)
-    l_name = name + "_input2.bin"
-    if sparse:
-        l_name = name + "_input1.bin"
-    save_t(l, l_name)
-    net.set_train(False)
-    y = net(x)
-    #train network
+def train_and_save(name, net, net_train, x, l, epoch):
+    """train_and_save"""
     net.set_train(True)
     for i in range(epoch):
         net_train(x, l)
@@ -72,7 +58,8 @@ def save_inout(name, x, l, net, net_train, sparse=False, epoch=1):
         y_name = name + "_output1.bin"
         save_t(y, y_name)

-def save_inout_transfer(name, x, l, net_bb, net, net_train, sparse=False, epoch=1):
+
+def save_inout(name, x, l, net, net_train, sparse=False, epoch=1):
     """save_inout"""
     x_name = name + "_input1.bin"
     if sparse:
@@ -84,25 +71,27 @@ def save_inout_transfer(name, x, l, net_bb, net, net_train, sparse=False, epoch=
         l_name = name + "_input1.bin"
     save_t(l, l_name)
+    net.set_train(False)
+    net(x)
+    train_and_save(name, net, net_train, x, l, epoch)
+
+
+def save_inout_transfer(name, x, l, net_bb, net, net_train, sparse=False, epoch=1):
+    """save_inout_transfer"""
+    x_name = name + "_input1.bin"
+    if sparse:
+        x_name = name + "_input2.bin"
+    save_t(Tensor(x.asnumpy().transpose(0, 2, 3, 1)), x_name)
+    l_name = name + "_input2.bin"
+    if sparse:
+        l_name = name + "_input1.bin"
+    save_t(l, l_name)
     net_bb.set_train(False)
-    net.set_train(False)
     x1 = net_bb(x)
-    y = net(x1)
-    #train network
-    net.set_train(True)
-    for i in range(epoch):
-        net_train(x1, l)
     net.set_train(False)
-    y = net(x1)
-    if isinstance(y, tuple):
-        i = 1
-        for t in y:
-            with os.fdopen(name + "_output" + str(i) + ".bin", 'w') as f:
-                for j in t.asnumpy().flatten():
-                    f.write(str(j)+' ')
-                i = i + 1
-    else:
-        y_name = name + "_output1.bin"
-        save_t(y, y_name)
+    net(x1)
+    train_and_save(name, net, net_train, x1, l, epoch)

View File

@@ -12,4 +12,3 @@ densenet
 shufflenetv2
 vgg noarm32
 xception
-albert_mlm

View File

@@ -56,6 +56,8 @@ constexpr int kNCHWCDim = 2;
 constexpr int kPrintTimes = 100;
 constexpr int kSaveSteps = 1000;
 constexpr float kGammaFactor = 0.7f;
+constexpr static int kElem2Print = 10;
+
 class Rescaler : public mindspore::session::TrainLoopCallBack {
  public:
   explicit Rescaler(float scale) : scale_(scale) {
@@ -126,7 +128,9 @@ bool after_callback(const std::vector<mindspore::tensor::MSTensor *> &after_inpu
     auto d = reinterpret_cast<float *>(after_outputs.at(i)->MutableData());
     int num2p = (after_outputs.at(i)->ElementsNum());
     printf("ou%zu(%d): ", i, num2p);
-    if (num2p > 10) num2p = 10;
+    if (num2p > kElem2Print) {
+      num2p = kElem2Print;
+    }
     for (int j = 0; j < num2p; j++) printf("%f, ", d[j]);
     printf("\n");
   }
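For reference, several hunks in this commit apply the same Coverity-driven pattern shown here: a bare numeric literal in a comparison is hoisted into a named constexpr constant, and a single-statement if body gains braces. A minimal standalone sketch of the pattern, with hypothetical names (not part of the commit):

#include <cstdio>

constexpr int kElem2Print = 10;  // named limit instead of a magic 10

// Print at most kElem2Print leading elements of a buffer.
void PrintHead(const float *data, int count) {
  int num2p = count;
  if (num2p > kElem2Print) {  // braced even though the body is a single statement
    num2p = kElem2Print;
  }
  for (int j = 0; j < num2p; j++) printf("%f, ", data[j]);
  printf("\n");
}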

View File

@@ -19,7 +19,7 @@ from mindspore import context, Tensor
 import mindspore.common.dtype as mstype
 from mindspore.train.serialization import export
 from lenet import LeNet5
-from train_utils import TrainWrap
+from train_utils import train_wrap

 n = LeNet5()
 n.set_train()
@@ -28,7 +28,7 @@ context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU", save_graphs
 BATCH_SIZE = 4
 x = Tensor(np.ones((BATCH_SIZE, 1, 32, 32)), mstype.float32)
 label = Tensor(np.zeros([BATCH_SIZE]).astype(np.int32))
-net = TrainWrap(n)
+net = train_wrap(n)
 export(net, x, label, file_name="lenet_tod", file_format='MINDIR')
 print("finished exporting")

View File

@@ -17,9 +17,10 @@
 import mindspore.nn as nn
 from mindspore.common.parameter import ParameterTuple

-def TrainWrap(net, loss_fn=None, optimizer=None, weights=None):
+
+def train_wrap(net, loss_fn=None, optimizer=None, weights=None):
     """
-    TrainWrap
+    train_wrap
     """
     if loss_fn is None:
         loss_fn = nn.SoftmaxCrossEntropyWithLogits(reduction='mean', sparse=True)

View File

@@ -199,7 +199,7 @@ int NetRunner::TrainLoop() {
       session_->Export(cpkt_fn);
     }

-    std::cout << i + 1 << ": Loss is " << loss << " [min=" << min_loss << "]" << std::endl;
+    std::cout << (i + 1) << ": Loss is " << loss << " [min=" << min_loss << "]" << std::endl;
     if ((i + 1) % kBatchNum == 0) {
       session_->Eval();
       float acc = CalculateAccuracy(ds_.test_data(), session_);

View File

@@ -57,6 +57,8 @@ constexpr int kNCHWCDim = 2;
 constexpr int kPrintTimes = 100;
 constexpr int kSaveEpochs = 3;
 constexpr float kGammaFactor = 0.7f;
+constexpr static int kElem2Print = 10;
+
 class Rescaler : public mindspore::TrainCallBack {
  public:
   explicit Rescaler(float scale) : scale_(scale) {
@@ -128,7 +130,9 @@ bool after_callback(const std::vector<mindspore::tensor::MSTensor *> &after_inpu
     auto d = reinterpret_cast<float *>(after_outputs.at(i)->MutableData());
     int num2p = (after_outputs.at(i)->ElementsNum());
     printf("ou%zu(%d): ", i, num2p);
-    if (num2p > 10) num2p = 10;
+    if (num2p > kElem2Print) {
+      num2p = kElem2Print;
+    }
     for (int j = 0; j < num2p; j++) printf("%f, ", d[j]);
     printf("\n");
   }

View File

@@ -20,7 +20,6 @@
 namespace mindspore {
 namespace lite {
-
 void GetMaxMinPerchannel(int channels, int one_filter_size, int i, int elem_count, const float *raw_datas,
                          bool channel_at_first, float *desired_max, float *desired_min) {
   float min = FLT_MAX;
@@ -99,6 +98,5 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl
   return RET_OK;
 }
-
 }  // namespace lite
 }  // namespace mindspore

View File

@@ -23,7 +23,6 @@
 #include "src/common/log_adapter.h"

 namespace mindspore {
-
 CkptSaver::CkptSaver(int save_every_n, const std::string &filename_prefix) {
   callback_impl_ = new CallbackImpl(new lite::CkptSaver(save_every_n, filename_prefix));
 }
@@ -37,5 +36,4 @@ CkptSaver::~CkptSaver() {
     delete internal_call_back;
   }
 }
-
 }  // namespace mindspore

View File

@@ -23,7 +23,6 @@
 #include "src/common/log_adapter.h"

 namespace mindspore {
-
 LossMonitor::LossMonitor(int print_every_n_steps) {
   callback_impl_ = new CallbackImpl(new lite::LossMonitor(print_every_n_steps));
 }
@@ -53,5 +52,4 @@ const std::vector<GraphPoint> &LossMonitor::GetLossPoints() {
   return (reinterpret_cast<lite::LossMonitor *>(internal_call_back))->GetLossPoints();
 }
-
 }  // namespace mindspore

View File

@@ -23,7 +23,6 @@
 #include "src/common/log_adapter.h"

 namespace mindspore {
-
 int StepLRLambda(float *lr, int epoch, void *lr_cb_data) {
   if ((lr == nullptr) || (lr_cb_data == nullptr)) {
     MS_LOG(ERROR) << "nullptr passed as input to MultiplicativeLRLambda";
@@ -51,5 +50,4 @@ LRScheduler::~LRScheduler() {
     delete internal_call_back;
   }
 }
-
 }  // namespace mindspore

View File

@@ -23,7 +23,6 @@
 #include "src/common/log_adapter.h"

 namespace mindspore {
-
 TrainAccuracy::TrainAccuracy(int print_every_n, int accuracy_metrics, const std::vector<int> &input_indexes,
                              const std::vector<int> &output_indexes) {
   callback_impl_ = new CallbackImpl(
@@ -55,5 +54,4 @@ const std::vector<GraphPoint> &TrainAccuracy::GetAccuracyPoints() {
   return (reinterpret_cast<lite::ClassificationTrainAccuracyMonitor *>(internal_call_back))->GetAccuracyPoints();
 }
-
 }  // namespace mindspore

View File

@@ -24,6 +24,8 @@
 #include "src/common/log_adapter.h"

 namespace mindspore {
+constexpr static int kMaxNumOfDevices = 2;
+
 Status A2L_ConvertContext(Context *a_context, lite::Context *l_context) {
   if ((a_context == nullptr) || (l_context == nullptr)) {
     MS_LOG(ERROR) << "Invalid context pointers.";
@@ -35,7 +37,7 @@ Status A2L_ConvertContext(Context *a_context, lite::Context *l_context) {
     MS_LOG(ERROR) << "Invalid device list.";
     return kLiteInputParamInvalid;
   }
-  if (device_list.size() > 2) {
+  if (device_list.size() > kMaxNumOfDevices) {
     MS_LOG(ERROR) << "Only CPU/CPU & GPU/CPU & NPU mode is supported.";
     return kLiteInputParamInvalid;
   }
@@ -71,7 +73,7 @@ Status A2L_ConvertContext(Context *a_context, lite::Context *l_context) {
   cpu_info.cpu_device_info_ = {cpu_context->GetEnableFP16(), mode};
   l_context->device_list_.push_back({lite::DT_CPU, cpu_info, cpu_context->GetProvider(),
                                      cpu_context->GetProviderDevice(), cpu_context->GetAllocator()});
-  if (device_list.size() == 2) {
+  if (device_list.size() == kMaxNumOfDevices) {
     lite::DeviceInfo device_info = {0};
     if (device_list[1]->GetDeviceType() == kGPU) {
       auto gpu_context = device_list[1]->Cast<GPUDeviceInfo>();

View File

@@ -23,7 +23,6 @@
 #include "src/common/log_adapter.h"

 namespace mindspore {
-
 AccuracyMetrics::AccuracyMetrics(int accuracy_metrics, const std::vector<int> &input_indexes,
                                  const std::vector<int> &output_indexes) {
   metrics_impl_ = new MetricsImpl(new lite::AccuracyMetrics(accuracy_metrics, input_indexes, output_indexes));
@@ -56,5 +55,4 @@ float AccuracyMetrics::Eval() {
   auto internal_metrics = metrics_impl_->GetInternalMetrics();
   return (reinterpret_cast<lite::AccuracyMetrics *>(internal_metrics))->Eval();
 }
-
 }  // namespace mindspore

View File

@@ -437,5 +437,4 @@ Status ModelImpl::Resize(const std::vector<MSTensor> &inputs, const std::vector<
   auto ret = session_->Resize(inner_input, truncated_shape);
   return static_cast<StatusCode>(ret);
 }
-
 }  // namespace mindspore

View File

@@ -24,7 +24,6 @@
 #include "src/common/log_adapter.h"

 namespace mindspore {
-
 Status A2L_ConvertConfig(const TrainCfg *a_train_cfg, lite::TrainCfg *l_train_cfg) {
   if ((a_train_cfg == nullptr) || (l_train_cfg == nullptr)) {
     MS_LOG(ERROR) << "Invalid train_cfg pointers";
@@ -39,5 +38,4 @@ Status A2L_ConvertConfig(const TrainCfg *a_train_cfg, lite::TrainCfg *l_train_cf
   return kSuccess;
 }
-
 }  // namespace mindspore

View File

@@ -40,60 +40,6 @@
 #include "src/train/train_session.h"

 namespace mindspore {
-using mindspore::lite::RET_ERROR;
-using mindspore::lite::RET_OK;
-
-std::shared_ptr<session::LiteSession> CreateTrainSession(std::shared_ptr<Graph::GraphData> graph_data,
-                                                         std::shared_ptr<TrainCfg> cfg, lite::Context *context) {
-  bool is_train_session = graph_data->IsTrainModel();
-  if (is_train_session) {
-    auto model = graph_data->lite_model();
-    if (model == nullptr || model->buf == nullptr) {
-      MS_LOG(ERROR) << "Lite model has been freed.";
-      return nullptr;
-    }
-    std::shared_ptr<session::LiteSession> shared_session;
-    lite::TrainSession *session = new lite::TrainSession();
-    if (session == nullptr) {
-      MS_LOG(ERROR) << "create session failed";
-      return nullptr;
-    }
-    shared_session.reset(session);
-    lite::TrainCfg train_cfg;
-    if (cfg != nullptr) {
-      auto status = A2L_ConvertConfig(cfg.get(), &train_cfg);
-      if (status != kSuccess) {
-        MS_LOG(ERROR) << "Failed to convert Config to Lite Config";
-        return nullptr;
-      }
-    }
-    auto ret = session->Init(context, &train_cfg);
-    if (ret != mindspore::lite::RET_OK) {
-      MS_LOG(ERROR) << "init session failed";
-      return nullptr;
-    }
-    ret = session->CompileTrainGraph(model);
-    if (ret != mindspore::lite::RET_OK) {
-      MS_LOG(ERROR) << "Compiling Train Graph session failed";
-      return nullptr;
-    }
-    return shared_session;
-  }
-  MS_LOG(DEBUG) << "Session is not a train session.";
-  return nullptr;
-}
-
-class UnifiedAPISupportTrain {
- public:
-  UnifiedAPISupportTrain() { CreateTrainSessionCallbackHolder(CreateTrainSession); }
-};
-
-UnifiedAPISupportTrain support_train_api;
-
 Status ModelImpl::PrepareMetrics(Model *model, std::vector<session::Metrics *> *out_ms,
                                  std::vector<session::Metrics *> *adapter_ms) {
   if (out_ms == nullptr || adapter_ms == nullptr) {
@@ -157,5 +103,4 @@ Status ModelImpl::ConvertCallbacks(Model *model, std::vector<TrainCallBack *> *i
   }
   return kSuccess;
 }
-
 }  // namespace mindspore

View File

@@ -0,0 +1,92 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include <unordered_map>
+#include <algorithm>
+#include "include/api/types.h"
+#include "include/api/context.h"
+#include "include/api/dual_abi_helper.h"
+#include "include/lite_session.h"
+#include "include/context.h"
+#include "include/api/callback/callback.h"
+#include "include/api/metrics/metrics.h"
+#include "src/lite_model.h"
+#include "src/runtime/inner_allocator.h"
+#include "src/common/string_util.h"
+#include "src/cxx_api/model/model_impl.h"
+#include "src/cxx_api/converters.h"
+#include "src/cxx_api/graph/graph_data.h"
+#include "src/cxx_api/tensor/tensor_impl.h"
+#include "src/cxx_api/tensor_utils.h"
+#include "src/cxx_api/metrics/metrics_adapter.h"
+#include "src/cxx_api/metrics/metrics_impl.h"
+#include "src/cxx_api/callback/callback_adapter.h"
+#include "src/cxx_api/callback/callback_impl.h"
+#include "src/common/log_adapter.h"
+#include "src/train/train_session.h"
+
+namespace mindspore {
+std::shared_ptr<session::LiteSession> CreateTrainSession(std::shared_ptr<Graph::GraphData> graph_data,
+                                                         std::shared_ptr<TrainCfg> cfg, lite::Context *context) {
+  bool is_train_session = graph_data->IsTrainModel();
+  if (is_train_session) {
+    auto model = graph_data->lite_model();
+    if (model == nullptr || model->buf == nullptr) {
+      MS_LOG(ERROR) << "Lite model has been freed.";
+      return nullptr;
+    }
+    std::shared_ptr<session::LiteSession> shared_session;
+    lite::TrainSession *session = new lite::TrainSession();
+    if (session == nullptr) {
+      MS_LOG(ERROR) << "create session failed";
+      return nullptr;
+    }
+    shared_session.reset(session);
+    lite::TrainCfg train_cfg;
+    if (cfg != nullptr) {
+      auto status = A2L_ConvertConfig(cfg.get(), &train_cfg);
+      if (status != kSuccess) {
+        MS_LOG(ERROR) << "Failed to convert Config to Lite Config";
+        return nullptr;
+      }
+    }
+    auto ret = session->Init(context, &train_cfg);
+    if (ret != mindspore::lite::RET_OK) {
+      MS_LOG(ERROR) << "init session failed";
+      return nullptr;
+    }
+    ret = session->CompileTrainGraph(model);
+    if (ret != mindspore::lite::RET_OK) {
+      MS_LOG(ERROR) << "Compiling Train Graph session failed";
+      return nullptr;
+    }
+    return shared_session;
+  }
+  MS_LOG(DEBUG) << "Session is not a train session.";
+  return nullptr;
+}
+
+class TrainSupport {
+ public:
+  TrainSupport() { CreateTrainSessionCallbackHolder(CreateTrainSession); }
+};
+
+TrainSupport support_train_api;
+}  // namespace mindspore

View File

@@ -47,13 +47,11 @@ int RegistryKernelImpl::RegCustomKernel(const std::string &arch, const std::stri
   std::unique_lock<std::mutex> lock(lock_);
   if (custom_kernel_creators_[provider][arch][type] == nullptr) {
     custom_kernel_creators_[provider][arch][type] =
-      reinterpret_cast<CreateKernel *>(malloc(data_type_length_ * sizeof(CreateKernel)));
+      reinterpret_cast<CreateKernel *>(calloc(data_type_length_, sizeof(CreateKernel)));
     if (custom_kernel_creators_[provider][arch][type] == nullptr) {
       MS_LOG(ERROR) << "malloc custom kernel creator fail!provider: " << provider << ", arch: " << arch;
       return RET_ERROR;
     }
-    memset(reinterpret_cast<void *>(custom_kernel_creators_[provider][arch][type]), 0,
-           data_type_length_ * sizeof(CreateKernel));
   }

   int data_type_index = data_type - kNumberTypeBegin - 1;
View File

@@ -24,6 +24,9 @@ using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_Select;

 namespace mindspore::kernel {
+constexpr static int kFirstIdx = 1;
+constexpr static int kSecondIdx = 2;
+
 int SelectCPUKernel::Init() { return RET_OK; }

 int SelectCPUKernel::ReSize() { return RET_OK; }
@@ -70,8 +73,8 @@ int SelectCPUKernel::Run() {
     MS_ASSERT(in_tensors_.at(1)->Size() == out_tensors_.at(0)->Size());
     auto size = in_tensors_.at(1)->ElementsNum();
     auto condition = static_cast<bool *>(bool_tensor->data_c());
-    auto input1 = static_cast<float *>(in_tensors_.at(1)->data_c());
-    auto input2 = static_cast<float *>(in_tensors_.at(2)->data_c());
+    auto input1 = static_cast<float *>(in_tensors_.at(kFirstIdx)->data_c());
+    auto input2 = static_cast<float *>(in_tensors_.at(kSecondIdx)->data_c());
     auto output = static_cast<float *>(out_tensors_.at(0)->data_c());
     for (int i = 0; i < size; i++) {
       output[i] = condition[i] ? input1[i] : input2[i];

View File

@@ -97,7 +97,7 @@ int Convolution1x1FP16CPUKernel::InitWeightBias() {
       return RET_ERROR;
     }
   }
-  void *bias_origin_tmp = IsTrainable() ? in_tensors_.at(2)->data_c() : origin_bias_;
+  void *bias_origin_tmp = IsTrainable() ? in_tensors_.at(kBiasIndex)->data_c() : origin_bias_;
   memcpy(bias_data_, bias_origin_tmp, output_channel * sizeof(float16_t));
   memset(reinterpret_cast<char *>(bias_data_) + bias_size, 0, size - bias_size);
 }
@@ -295,5 +295,4 @@ int Convolution1x1FP16CPUKernel::Eval() {
   }
   return InnerKernel::Eval();
 }
-
 }  // namespace mindspore::kernel

View File

@@ -25,6 +25,15 @@ using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_FusedBatchNorm;

 namespace mindspore::kernel {
+constexpr static int kInScaleIdx = 1;
+constexpr static int kInOffsetIdx = 2;
+constexpr static int kInCurrentMeanIdx = 3;
+constexpr static int kInCurrentVarIdx = 4;
+constexpr static int kMaxInIdx = 5;
+constexpr static int kOutScaleIdx = 1;
+constexpr static int kOutOffsetIdx = 2;
+constexpr static int kOutCurrentMeanIdx = 3;
+constexpr static int kOutCurrentVarIdx = 4;

 void FusedBatchnormFp16CPUKernel::CalcMeanVar(float16_t *in, float16_t *scale, float16_t *offset, float16_t *save_mean,
                                               float16_t *save_variance) {
@@ -32,18 +41,18 @@ void FusedBatchnormFp16CPUKernel::CalcMeanVar(float16_t *in, float16_t *scale, f
   float16_t *current_mean = static_cast<float16_t *>(mean_);
   float16_t *current_var = static_cast<float16_t *>(variance_);
-  std::fill(current_mean, current_mean + in_tensors_.at(3)->ElementsNum(), 0.f);
-  std::fill(current_var, current_var + in_tensors_.at(4)->ElementsNum(), 0.f);
+  std::fill(current_mean, current_mean + in_tensors_.at(kInCurrentMeanIdx)->ElementsNum(), 0.f);
+  std::fill(current_var, current_var + in_tensors_.at(kInCurrentVarIdx)->ElementsNum(), 0.f);
   FusedBatchNormFp16MeanVar(in, current_mean, current_var, param, save_mean, save_variance);

-  memcpy(out_tensors_.at(1)->data_c(), scale, out_tensors_.at(1)->Size());
-  memcpy(out_tensors_.at(2)->data_c(), offset, out_tensors_.at(2)->Size());
-  memcpy(out_tensors_.at(3)->data_c(), current_mean, out_tensors_.at(3)->Size());
-  memcpy(out_tensors_.at(4)->data_c(), current_var, out_tensors_.at(4)->Size());
+  memcpy(out_tensors_.at(kOutScaleIdx)->data_c(), scale, out_tensors_.at(kOutScaleIdx)->Size());
+  memcpy(out_tensors_.at(kOutOffsetIdx)->data_c(), offset, out_tensors_.at(kOutOffsetIdx)->Size());
+  memcpy(out_tensors_.at(kOutCurrentMeanIdx)->data_c(), current_mean, out_tensors_.at(kOutCurrentMeanIdx)->Size());
+  memcpy(out_tensors_.at(kOutCurrentVarIdx)->data_c(), current_var, out_tensors_.at(kOutCurrentVarIdx)->Size());

   // Copy to local variables
-  memcpy(scale_, scale, in_tensors_[1]->Size());
-  memcpy(offset_, offset, in_tensors_[2]->Size());
+  memcpy(scale_, scale, in_tensors_[kInScaleIdx]->Size());
+  memcpy(offset_, offset, in_tensors_[kInOffsetIdx]->Size());
   trained_ = true;  // trained at least once
 }
@@ -52,13 +61,13 @@ int FusedBatchnormFp16CPUKernel::DoExecute(int task_id) {
   auto param = reinterpret_cast<BatchNormParameter *>(op_parameter_);
   MS_ASSERT(param);
   if (in_tensors_.at(0)->data_type() == kNumberTypeFloat32) {
-    MS_ASSERT(in_tensors_.size() == 5);
+    MS_ASSERT(in_tensors_.size() == kMaxInIdx);
     MS_ASSERT(out_tensors_.size() == 1);
     auto input = in_tensors_.at(0);
-    auto scale = in_tensors_.at(1);
-    auto offset = in_tensors_.at(2);
-    auto mean = in_tensors_.at(3);
-    auto variance = in_tensors_.at(4);
+    auto scale = in_tensors_.at(kInScaleIdx);
+    auto offset = in_tensors_.at(kInOffsetIdx);
+    auto mean = in_tensors_.at(kInCurrentMeanIdx);
+    auto variance = in_tensors_.at(kInCurrentVarIdx);
     auto output = out_tensors_.at(0);
     auto input_fp16 = ms_context_->allocator->Malloc(input->ElementsNum() * sizeof(float16_t));
@@ -88,7 +97,7 @@ int FusedBatchnormFp16CPUKernel::DoExecute(int task_id) {
     Float32ToFloat16(reinterpret_cast<float *>(variance->data_c()), reinterpret_cast<float16_t *>(variance_fp16),
                      variance->ElementsNum());

-    if (IsTrain() && IsTrainable() && in_tensors_.size() >= 5) {
+    if (IsTrain() && IsTrainable() && in_tensors_.size() >= kMaxInIdx) {
       CalcMeanVar(reinterpret_cast<float16_t *>(input_fp16), reinterpret_cast<float16_t *>(scale_fp16),
                   reinterpret_cast<float16_t *>(offset_fp16), reinterpret_cast<float16_t *>(mean_fp16),
                   reinterpret_cast<float16_t *>(variance_fp16));
@@ -108,11 +117,12 @@ int FusedBatchnormFp16CPUKernel::DoExecute(int task_id) {
     return RET_OK;
   }

-  if (IsTrain() && IsTrainable() && in_tensors_.size() >= 5) {
-    CalcMeanVar(
-      static_cast<float16_t *>(in_tensors_.at(0)->data_c()), static_cast<float16_t *>(in_tensors_.at(1)->data_c()),
-      static_cast<float16_t *>(in_tensors_.at(2)->data_c()), static_cast<float16_t *>(in_tensors_.at(3)->data_c()),
-      static_cast<float16_t *>(in_tensors_.at(4)->data_c()));
+  if (IsTrain() && IsTrainable() && in_tensors_.size() >= kMaxInIdx) {
+    CalcMeanVar(static_cast<float16_t *>(in_tensors_.at(0)->data_c()),
+                static_cast<float16_t *>(in_tensors_.at(kInScaleIdx)->data_c()),
+                static_cast<float16_t *>(in_tensors_.at(kInOffsetIdx)->data_c()),
+                static_cast<float16_t *>(in_tensors_.at(kInCurrentMeanIdx)->data_c()),
+                static_cast<float16_t *>(in_tensors_.at(kInCurrentVarIdx)->data_c()));
   }
   FusedBatchNormFp16(in_tensors_.at(0)->data_c(), scale_, offset_, mean_, variance_, param, task_id,
                      out_tensors_.at(0)->data_c());
@@ -122,16 +132,16 @@ int FusedBatchnormFp16CPUKernel::DoExecute(int task_id) {
 int FusedBatchnormFp16CPUKernel::Eval() {
   InnerKernel::Eval();
   if (trained_) {
-    float16_t *save_mean = static_cast<float16_t *>(in_tensors_.at(3)->data_c());
-    float16_t *save_var = static_cast<float16_t *>(in_tensors_.at(4)->data_c());
-    float16_t *scale = static_cast<float16_t *>(in_tensors_.at(1)->data_c());
-    float16_t *bias = static_cast<float16_t *>(in_tensors_.at(2)->data_c());
+    float16_t *save_mean = static_cast<float16_t *>(in_tensors_.at(kInCurrentMeanIdx)->data_c());
+    float16_t *save_var = static_cast<float16_t *>(in_tensors_.at(kInCurrentVarIdx)->data_c());
+    float16_t *scale = static_cast<float16_t *>(in_tensors_.at(kInScaleIdx)->data_c());
+    float16_t *bias = static_cast<float16_t *>(in_tensors_.at(kInOffsetIdx)->data_c());
     // Copy to local variables
-    memcpy(scale_, scale, in_tensors_.at(1)->Size());
-    memcpy(offset_, bias, in_tensors_.at(2)->Size());
-    memcpy(mean_, save_mean, in_tensors_.at(3)->Size());
-    memcpy(variance_, save_var, in_tensors_.at(4)->Size());
+    memcpy(scale_, scale, in_tensors_.at(kInScaleIdx)->Size());
+    memcpy(offset_, bias, in_tensors_.at(kInOffsetIdx)->Size());
+    memcpy(mean_, save_mean, in_tensors_.at(kInCurrentMeanIdx)->Size());
+    memcpy(variance_, save_var, in_tensors_.at(kInCurrentVarIdx)->Size());
   }
   return RET_OK;
 }
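The batchnorm kernel above relies on a fixed tensor ordering (input, scale, offset, mean, variance); naming each position turns an opaque at(3) into self-documenting code and keeps the size check (kMaxInIdx) next to the layout it guards. A simplified standalone sketch of the idea, with hypothetical types and names (not part of the commit):

#include <cassert>
#include <vector>

using Tensor = std::vector<float>;  // stand-in for the real lite::Tensor type

// Fixed input layout of a fused-batchnorm-style kernel.
constexpr int kInScaleIdx = 1;
constexpr int kInOffsetIdx = 2;
constexpr int kInMeanIdx = 3;
constexpr int kInVarIdx = 4;
constexpr size_t kNumInputs = 5;

void CheckLayout(const std::vector<Tensor> &in_tensors) {
  // The size check lives next to the named layout it guards.
  assert(in_tensors.size() == kNumInputs);
  // Named indices document *why* a given position is read.
  assert(in_tensors.at(kInScaleIdx).size() == in_tensors.at(kInOffsetIdx).size());
  assert(in_tensors.at(kInMeanIdx).size() == in_tensors.at(kInVarIdx).size());
}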

View File

@@ -26,14 +26,17 @@ using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;

 namespace mindspore::kernel {
+constexpr static int kX1Idx = 0;
+constexpr static int kX2Idx = 1;
+constexpr static int kDyIdx = 2;
+
 int ArithmeticGradCPUKernelFp16::Init() { return RET_OK; }

 void ArithmeticGradCPUKernelFp16::ArithmeticGradMaximum(float16_t *dy, int dy_size, float16_t *dx1, int dx1_size,
                                                         float16_t *dx2, int dx2_size) {
-  auto x1 = reinterpret_cast<float16_t *>(in_tensors_[0]->data_c());
-  auto x2 = reinterpret_cast<float16_t *>(in_tensors_[1]->data_c());
-  dy = reinterpret_cast<float16_t *>(in_tensors_[2]->data_c());
+  auto x1 = reinterpret_cast<float16_t *>(in_tensors_[kX1Idx]->data_c());
+  auto x2 = reinterpret_cast<float16_t *>(in_tensors_[kX2Idx]->data_c());
+  dy = reinterpret_cast<float16_t *>(in_tensors_[kDyIdx]->data_c());
   MaximumByAxesFp16(x1, x2, dy, arithmeticParameter_->in_shape0_, arithmeticParameter_->in_shape1_,
                     arithmeticParameter_->out_shape_, dx1, dx2, arithmeticParameter_->ndim_);
@@ -41,9 +44,9 @@ void ArithmeticGradCPUKernelFp16::ArithmeticGradMaximum(float16_t *dy, int dy_si
 void ArithmeticGradCPUKernelFp16::ArithmeticGradMinimum(float16_t *dy, int dy_size, float16_t *dx1, int dx1_size,
                                                         float16_t *dx2, int dx2_size) {
-  auto x1 = reinterpret_cast<float16_t *>(in_tensors_[0]->data_c());
-  auto x2 = reinterpret_cast<float16_t *>(in_tensors_[1]->data_c());
-  dy = reinterpret_cast<float16_t *>(in_tensors_[2]->data_c());
+  auto x1 = reinterpret_cast<float16_t *>(in_tensors_[kX1Idx]->data_c());
+  auto x2 = reinterpret_cast<float16_t *>(in_tensors_[kX2Idx]->data_c());
+  dy = reinterpret_cast<float16_t *>(in_tensors_[kDyIdx]->data_c());
   MinimumByAxesFp16(x1, x2, dy, arithmeticParameter_->in_shape0_, arithmeticParameter_->in_shape1_,
                     arithmeticParameter_->out_shape_, dx1, dx2, arithmeticParameter_->ndim_);

View File

@@ -27,6 +27,7 @@ using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_BiasAddGrad;

 namespace mindspore::kernel {
+constexpr static int kMaxDim = 4;
 int BiasGradCPUKernelFp16::ReSize() {
   auto dims = in_tensors_[0]->shape();
@@ -36,7 +37,7 @@ int BiasGradCPUKernelFp16::ReSize() {
     bias_param->out_shape_[i] = 1;  // 1 dimension for N,H,W,
   }
   bias_param->out_shape_[bias_param->ndim_ - 1] = dims[bias_param->ndim_ - 1];
-  for (auto i = bias_param->ndim_; i < 4; i++) {
+  for (auto i = bias_param->ndim_; i < kMaxDim; i++) {
     bias_param->in_shape0_[i] = 0;
     bias_param->out_shape_[i] = 0;
   }

View File

@@ -13,7 +13,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#define _STUB
 #include "src/train/train_export.h"
 #include <unistd.h>
 #include <sys/stat.h>
@@ -29,6 +28,8 @@
 namespace mindspore {
 namespace lite {
+constexpr static int kFmkVal = 3;
+constexpr static int kTransformTensorDim = 4;
 std::vector<uint8_t> TrainExport::CreateData(const lite::Tensor *tensor) {
   uint8_t *tensor_data = reinterpret_cast<uint8_t *>(tensor->data_c());
@@ -209,7 +210,7 @@ std::unique_ptr<schema::TensorT> TrainExport::CreateTransformTensor(size_t id) {
   tensorT->dataType = scTensor->dataType;
   std::vector<int32_t> dims;
   std::vector<int32_t> val = {0, 2, 3, 1};
-  if (scTensor->dims.size() == 4) {
+  if (scTensor->dims.size() == kTransformTensorDim) {
     for (size_t i = 0; i < val.size(); i++) {
       dims.push_back(scTensor->dims.at(val[i]));
     }
@@ -233,7 +234,7 @@ std::unique_ptr<schema::TensorT> TrainExport::CreateTransformConst(size_t last_i
   }
   tensorT->nodeType = lite::NodeType_ValueNode;
   tensorT->dataType = TypeId::kNumberTypeInt32;
-  tensorT->dims = {4};
+  tensorT->dims = {kTransformTensorDim};
   tensorT->format = schema::Format_NCHW;
   tensorT->name = "const-" + std::to_string(last_id);
   tensorT->refCount = 0;
@@ -406,7 +407,7 @@ int TrainExport::ExportInit(const std::string model_name, std::string version) {
     MS_LOG(ERROR) << "cannot allocate meta_graph";
     return RET_ERROR;
   }
-  meta_graph_->fmkType = 3;
+  meta_graph_->fmkType = kFmkVal;
   meta_graph_->name = model_name;
   meta_graph_->version = version;
   return RET_OK;
@@ -420,6 +421,5 @@ int TrainExport::IsInputTensor(const schema::TensorT &t) {
 }

 TrainExport::~TrainExport() { delete meta_graph_; }
-
 }  // namespace lite
 }  // namespace mindspore
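The {0, 2, 3, 1} vector in CreateTransformTensor above is an NCHW-to-NHWC axis permutation, and kTransformTensorDim guards that it is only applied to 4-D shapes. A small self-contained sketch of how the permutation maps dimensions, with a hypothetical helper (not part of the commit):

#include <cstdint>
#include <cstdio>
#include <vector>

constexpr size_t kTransformTensorDim = 4;  // the permutation only applies to 4-D shapes

// Reorder an NCHW shape into NHWC with the same {0, 2, 3, 1} index vector.
std::vector<int32_t> PermuteNchwToNhwc(const std::vector<int32_t> &dims) {
  const std::vector<int32_t> val = {0, 2, 3, 1};
  std::vector<int32_t> out;
  if (dims.size() == kTransformTensorDim) {
    for (size_t i = 0; i < val.size(); i++) {
      out.push_back(dims.at(val[i]));
    }
  }
  return out;
}

int main() {
  for (int32_t d : PermuteNchwToNhwc({1, 3, 32, 32})) printf("%d ", d);  // prints: 1 32 32 3
  printf("\n");
  return 0;
}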

View File

@@ -31,7 +31,6 @@
 #include "nnacl/fp32_grad/resize_grad.h"

 namespace mindspore {
 namespace kernel {
-
 OpParameter *PopulateSmoothL1LossParameter(const void *prim) {
   SmoothL1LossParameter *p = reinterpret_cast<SmoothL1LossParameter *>(malloc(sizeof(SmoothL1LossParameter)));
   if (p == nullptr) {

View File

@@ -198,7 +198,7 @@ TrainSession::~TrainSession() { FreeWorkSpace(); }
 int TrainSession::ExecKernels(const KernelCallBack &before, const KernelCallBack &after,
                               const std::vector<kernel::LiteKernel *> &run_kernels) {
   for (auto *kernel : run_kernels) {
-    MS_ASSERT(nullptr != kernel);
+    MS_ASSERT(kernel != nullptr);
     auto ret = kernel->Execute(before, after);
     if (RET_OK != ret) {
       MS_LOG(ERROR) << "Execute kernel failed, name: " << kernel->name();
@@ -309,7 +309,7 @@ int TrainSession::MixPrecisionExecKernels(const KernelCallBack &before, const Ke
                                           const std::vector<kernel::LiteKernel *> &run_kernels) {
   float scale = cfg_.mix_precision_cfg_.loss_scale_;
   for (auto *kernel : run_kernels) {
-    MS_ASSERT(nullptr != kernel);
+    MS_ASSERT(kernel != nullptr);
     MixPrecisionPreProcess(kernel, scale);
     auto ret = kernel->Execute(before, after);
     if (RET_OK != ret) {
@@ -398,7 +398,7 @@ int TrainSession::Train() {
   train_mode_ = true;
   virtual_batch_idx_ = 0;
   for (auto &kernel : this->train_kernels_) {
-    MS_ASSERT(nullptr != kernel);
+    MS_ASSERT(kernel != nullptr);
     auto ret = kernel->Train();
     if (ret != RET_OK) {
       MS_LOG(ERROR) << kernel->name() << " failed to set train mode";
@@ -791,5 +791,4 @@ session::LiteSession *session::TrainSession::CreateTrainSession(const std::strin
   }
   return session.release();
 }
-
 }  // namespace mindspore

View File

@@ -26,7 +26,6 @@
 namespace mindspore {
 namespace lite {
-
 size_t TSFindTensor(const std::vector<lite::Tensor *> &where, const lite::Tensor *searchParameter) {
   for (size_t i = 0; i < where.size(); i++) {
     if (where[i] == searchParameter) {
@@ -199,6 +198,5 @@ int ScaleTensor(Tensor *tensor, float scale) {
   MS_LOG(DEBUG) << "Scale tensor: " << tensor->tensor_name() << " " << scale;
   return tensor->Scale<float>(scale);
 }
-
 }  // namespace lite
 }  // namespace mindspore

View File

@@ -39,7 +39,6 @@
 namespace mindspore {
 namespace lite {
-
 TransferSession::TransferSession(const char *model_buf_backbone, size_t size_backbone, const lite::Context *context)
     : is_valid_(false) {
   lite_model_ = reinterpret_cast<char *>(malloc(size_backbone));
@@ -92,18 +91,18 @@ int TransferSession::CompileTransferGraph() {
           nchw2nhwc_ = CompileFormatTransform(output, input, nchw2nhwc_mask, 4);
           match = nchw2nhwc_;
         }
-        if (true == match) {
+        if (match) {
           break;
         }
       }
     }
-    if (true == match) {
+    if (match) {
       backbone_head_map_.push_back(std::make_pair(input, output));
     } else {
      combined_inputs_.push_back(input);
    }
  }
-  if (0 == backbone_head_map_.size()) {
+  if (backbone_head_map_.size() == 0) {
     ret = RET_ERROR;
   }
   return ret;
@@ -113,7 +112,7 @@ mindspore::tensor::MSTensor *TransferSession::GetInputsByTensorName(const std::s
   /* First look in backbone netwok */
   auto ret = backbone_session_->GetInputsByTensorName(tensor_name);
   /* If not found look in head network */
-  if (nullptr == ret) {
+  if (ret == nullptr) {
     ret = TrainSession::GetInputsByTensorName(tensor_name);
   }
   return ret;
@@ -220,7 +219,6 @@ int TransferSession::Export(const std::string &filename, ModelType model_type, Q
   if (orig_train_state) Train();
   return status;
 }
-
 }  // namespace lite

 static session::LiteSession *CreateTransferSessionInt(const char *model_buf_backbone, size_t size_backbone,
@@ -316,5 +314,4 @@ session::LiteSession *session::TrainSession::CreateTransferSession(const std::st
   }
   return CreateTransferSessionInt(buf_backbone, size_backbone, buf_head, size_head, ctxt, train_mode, cfg);
 }
-
 }  // namespace mindspore

View File

@@ -38,6 +38,15 @@ static const char *DELIM_SLASH = "/";
 constexpr const char *DELIM_COLON = ":";
 constexpr const char *DELIM_COMMA = ",";
 constexpr int RET_TOO_BIG = -9;
+constexpr int kField0 = 0;
+constexpr int kField1 = 1;
+constexpr int kField2 = 2;
+constexpr int kField3 = 3;
+constexpr int kField4 = 4;
+constexpr int kFieldsToPrint = 5;
+constexpr int kPrintOffset = 4;
+constexpr int kCPUBindFlag2 = 2;
+constexpr int kCPUBindFlag1 = 1;

 namespace {
 float *ReadFileBuf(const char *file, size_t *size) {
@@ -60,7 +69,7 @@ float *ReadFileBuf(const char *file, size_t *size) {
   ifs.seekg(0, std::ios::end);
   *size = ifs.tellg();
-  std::unique_ptr<float[]> buf((new (std::nothrow) float[*size / sizeof(float) + 1]));
+  std::unique_ptr<float[]> buf = std::make_unique<float[]>(*size / sizeof(float) + 1);
   if (buf == nullptr) {
     MS_LOG(ERROR) << "malloc buf failed, file: " << real_path;
     ifs.close();
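One nuance of the make_unique change above: new (std::nothrow) reports allocation failure through a null pointer, while std::make_unique throws std::bad_alloc, so a following null check no longer fires. A standalone sketch contrasting the two forms, with hypothetical function names (not part of the commit):

#include <memory>
#include <new>

// nothrow form: allocation failure is reported through a null pointer.
std::unique_ptr<float[]> AllocNothrow(size_t n) {
  std::unique_ptr<float[]> buf(new (std::nothrow) float[n]);
  return buf;  // may be nullptr; the caller's null check is meaningful
}

// make_unique form: failure surfaces as a std::bad_alloc exception instead,
// so a subsequent null check can never be true.
std::unique_ptr<float[]> AllocThrowing(size_t n) {
  return std::make_unique<float[]>(n);  // also value-initializes (zeroes) the array
}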
@@ -136,7 +145,7 @@ int NetTrain::ReadInputFile(std::vector<mindspore::tensor::MSTensor *> *ms_input
     MS_ASSERT(cur_tensor != nullptr);
     size_t size;
     std::string file_name = flags_->in_data_file_ + std::to_string(i + 1) + ".bin";
-    char *bin_buf = ReadFile(file_name.c_str(), &size);
+    auto bin_buf = ReadFile(file_name.c_str(), &size);
     if (bin_buf == nullptr) {
       MS_LOG(ERROR) << "ReadFile return nullptr";
       return RET_ERROR;
@@ -312,10 +321,10 @@ int NetTrain::MarkAccuracy(const std::unique_ptr<session::LiteSession> &session,
 }

 static CpuBindMode FlagToBindMode(int flag) {
-  if (flag == 2) {
+  if (flag == kCPUBindFlag2) {
     return MID_CPU;
   }
-  if (flag == 1) {
+  if (flag == kCPUBindFlag1) {
     return HIGHER_CPU;
   }
   return NO_BIND;
@@ -337,7 +346,6 @@ std::unique_ptr<session::LiteSession> NetTrain::CreateAndRunNetworkForTrain(cons
       std::cout << "RunNetTrain CreateTranferSession failed while running " << model_name.c_str() << std::endl;
       return nullptr;
     }
-
   } else {
     MS_LOG(INFO) << "CreateTrainSession from model file" << filename.c_str();
     std::cout << "CreateTrainSession from model file " << filename.c_str() << std::endl;
@@ -428,8 +436,8 @@ int NetTrain::CreateAndRunNetwork(const std::string &filename, const std::string
   }
   auto end_prepare_time = GetTimeUs();
-  MS_LOG(INFO) << "PrepareTime = " << (end_prepare_time - start_prepare_time) / 1000 << " ms";
-  std::cout << "PrepareTime = " << (end_prepare_time - start_prepare_time) / 1000 << " ms" << std::endl;
+  MS_LOG(INFO) << "PrepareTime = " << ((end_prepare_time - start_prepare_time) / 1000) << " ms";
+  std::cout << "PrepareTime = " << ((end_prepare_time - start_prepare_time) / 1000) << " ms" << std::endl;

   // Load input
   MS_LOG(INFO) << "Load input data";
   auto ms_inputs = session->GetInputs();
@@ -719,43 +727,47 @@ int NetTrain::Init() {
 int NetTrain::PrintResult(const std::vector<std::string> &title,
                           const std::map<std::string, std::pair<int, float>> &result) {
-  std::vector<size_t> columnLenMax(5);
+  std::vector<size_t> columnLenMax(kFieldsToPrint);
   std::vector<std::vector<std::string>> rows;

   for (auto &iter : result) {
-    char stringBuf[5][100] = {};
+    std::string stringBuf[kFieldsToPrint];
     std::vector<std::string> columns;
     size_t len;

     len = iter.first.size();
-    if (len > columnLenMax.at(0)) {
-      columnLenMax.at(0) = len + 4;
+    if (len > columnLenMax.at(kField0)) {
+      columnLenMax.at(kField0) = len + kPrintOffset;
     }
     columns.push_back(iter.first);

-    len = snprintf(stringBuf[1], sizeof(stringBuf[1]), "%f", iter.second.second / flags_->epochs_);
-    if (len > columnLenMax.at(1)) {
-      columnLenMax.at(1) = len + 4;
+    stringBuf[kField1] = to_string(iter.second.second / flags_->epochs_);
+    len = stringBuf[kField1].length();
+    if (len > columnLenMax.at(kField1)) {
+      columnLenMax.at(kField1) = len + kPrintOffset;
     }
-    columns.emplace_back(stringBuf[1]);
+    columns.emplace_back(stringBuf[kField1]);

-    len = snprintf(stringBuf[2], sizeof(stringBuf[2]), "%f", iter.second.second / op_cost_total_);
-    if (len > columnLenMax.at(2)) {
-      columnLenMax.at(2) = len + 4;
+    stringBuf[kField2] = to_string(iter.second.second / op_cost_total_);
+    len = stringBuf[kField2].length();
+    if (len > columnLenMax.at(kField2)) {
+      columnLenMax.at(kField2) = len + kPrintOffset;
     }
-    columns.emplace_back(stringBuf[2]);
+    columns.emplace_back(stringBuf[kField2]);

-    len = snprintf(stringBuf[3], sizeof(stringBuf[3]), "%d", iter.second.first);
-    if (len > columnLenMax.at(3)) {
-      columnLenMax.at(3) = len + 4;
+    stringBuf[kField3] = to_string(iter.second.first);
+    len = stringBuf[kField3].length();
+    if (len > columnLenMax.at(kField3)) {
+      columnLenMax.at(kField3) = len + kPrintOffset;
     }
-    columns.emplace_back(stringBuf[3]);
+    columns.emplace_back(stringBuf[kField3]);

-    len = snprintf(stringBuf[4], sizeof(stringBuf[4]), "%f", iter.second.second);
-    if (len > columnLenMax.at(4)) {
-      columnLenMax.at(4) = len + 4;
+    stringBuf[kField4] = to_string(iter.second.second);
+    len = stringBuf[kField4].length();
+    if (len > columnLenMax.at(kField4)) {
+      columnLenMax.at(kField4) = len + kPrintOffset;
     }
-    columns.emplace_back(stringBuf[4]);
+    columns.emplace_back(stringBuf[kField4]);

     rows.push_back(columns);
   }
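The PrintResult rewrite above swaps fixed-size snprintf buffers for std::string via to_string, which removes any possibility of truncating into a 100-byte buffer. A condensed standalone sketch of the width-tracking idea, with hypothetical data (not part of the commit):

#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>

constexpr int kPrintOffset = 4;  // padding added beyond the widest entry

int main() {
  std::vector<float> values = {12.5f, 3.25f, 1234.0f};
  std::vector<std::string> cells;
  size_t column_width = 0;
  for (float v : values) {
    // std::to_string cannot truncate or overflow the way a fixed snprintf buffer can.
    cells.push_back(std::to_string(v));
    column_width = std::max(column_width, cells.back().length() + kPrintOffset);
  }
  for (const auto &cell : cells) {
    printf("%-*s\n", static_cast<int>(column_width), cell.c_str());
  }
  return 0;
}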

View File

@@ -55,6 +55,9 @@ using std::vector;
 namespace mindspore::lite::quant {
 const std::vector<std::string> QuantStrategy::conv_types_ = {ops::kNameConv2DFusion, ops::kNameConv2dTransposeFusion};
 const std::vector<std::string> QuantStrategy::mul_types_ = {ops::kNameMatMul, ops::kNameFullConnection};
+constexpr int kDim2 = 2;
+constexpr int kDim4 = 4;
+
 QuantStrategy::QuantStrategy(size_t weight_size, size_t conv_weight_quant_channel_threshold)
     : m_weight_size_(weight_size), m_conv_weight_quant_channel_threshold_(conv_weight_quant_channel_threshold) {}
@@ -209,7 +212,7 @@ bool QuantStrategy::CanTensorQuantized(const AnfNodePtr &inputNode) const {
   }

   auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape();
-  if (weight_shape.size() < 2) {  // do not quant single dim tensors
+  if (weight_shape.size() < kDim2) {  // do not quant single dim tensors
     return false;
   }
@@ -222,7 +225,7 @@ bool QuantStrategy::CanTensorQuantized(const AnfNodePtr &inputNode) const {
     return false;
   }

-  if (weight_shape.size() == 4) {  // assume Convolution
+  if (weight_shape.size() == kDim4) {  // assume Convolution
     if (weight_shape[0] <= static_cast<int>(m_conv_weight_quant_channel_threshold_)) {
       MS_LOG(INFO) << "channel less m_conv_weight_quant_channel_threshold_!" << weight_shape[0];
       return false;