forked from mindspore-Ecosystem/mindspore
!13687 [MSLITE][TOD] Fix Issues. Support BatchNormGrad for CPU
From: @yonibaehr_admin Reviewed-by: @HilbertDavid,@hangangqiang Signed-off-by: @HilbertDavid
Commit: 35297745f5

@@ -28,7 +28,7 @@ def TrainWrap(net, loss_fn=None, optimizer=None, weights=None):
     if weights is None:
         weights = ParameterTuple(net.trainable_params())
     if optimizer is None:
-        optimizer = nn.Adam(weights, learning_rate=1e-2, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
-                            use_nesterov=False, weight_decay=0.0, loss_scale=1.0)
+        optimizer = nn.Adam(weights, learning_rate=0.003, beta1=0.9, beta2=0.999, eps=1e-5, use_locking=False,
+                            use_nesterov=False, weight_decay=4e-5, loss_scale=1.0)
     train_net = nn.TrainOneStepCell(loss_net, optimizer)
     return train_net

@@ -159,7 +159,7 @@ int NetRunner::InitDB() {
 }
 
 int NetRunner::TrainLoop() {
-  struct mindspore::lite::StepLRLambda step_lr_lambda(1, 0.8);
+  struct mindspore::lite::StepLRLambda step_lr_lambda(1, 0.7);
   mindspore::lite::LRScheduler step_lr_sched(mindspore::lite::StepLRLambda, static_cast<void *>(&step_lr_lambda), 1);
 
   mindspore::lite::LossMonitor lm(100);

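A note on the scheduler change above: assuming StepLRLambda(step_size, gamma) means "multiply the learning rate by gamma every step_size epochs" (an assumption based on the argument names, not something this diff confirms), lowering gamma from 0.8 to 0.7 simply makes the per-epoch decay steeper. A minimal standalone sketch of that decay rule, not the MindSpore Lite implementation:

#include <cmath>
#include <cstdio>

// Assumed semantics: every step_size epochs the learning rate is
// multiplied by gamma (0.7 after this change, previously 0.8).
float StepDecay(float base_lr, int epoch, int step_size, float gamma) {
  int decays = epoch / step_size;           // completed decay intervals
  return base_lr * std::pow(gamma, decays);
}

int main() {
  const float base_lr = 0.003f;             // matches the new Adam learning_rate above
  for (int epoch = 0; epoch < 5; ++epoch) {
    std::printf("epoch %d -> lr %.6f\n", epoch, StepDecay(base_lr, epoch, 1, 0.7f));
  }
  return 0;
}

Under that assumption, with step_size = 1 and gamma = 0.7 the learning rate after n epochs is 0.003 * 0.7^n.
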
@@ -8,7 +8,7 @@ fi
 echo "============Exporting=========="
 if [ -n "$1" ]; then
   DOCKER_IMG=$1
-  rm *.so*
+  rm -f *.so*
   docker run -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER --privileged=true ${DOCKER_IMG} /bin/bash -c "python transfer_learning_export.py; chmod 444 transfer_learning_tod*.mindir; rm -rf __pycache__"
 else
   echo "MindSpore docker was not provided, attempting to run locally"

@@ -18,7 +18,7 @@
 #include <math.h>
 #include <algorithm>
 #include <vector>
-
+#include <string>
 #include <thread>
 #include <fstream>
 

@@ -51,6 +51,12 @@ int BNGradCPUKernel::Execute(int task_id) {
   auto *input_scale = in_tensors_.at(2);
   auto *input_mean = in_tensors_.at(3);
   auto *input_var = in_tensors_.at(4);
+
+  auto kernel_name = this->name();
+  if (kernel_name.find("FusedBatchNormGradCPU") != std::string::npos) {
+    input_mean = in_tensors_.at(4);
+    input_var = in_tensors_.at(5);
+  }
   auto bn_param = reinterpret_cast<BNGradParameter *>(op_parameter_);
   int stage = stage_;
   int thread_num = thread_num_;

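The added branch above handles the fused variant of the kernel: when the node name contains "FusedBatchNormGradCPU", the running mean and variance are read from inputs 4 and 5 rather than 3 and 4. A tiny illustration of that index selection (the layout is taken from the diff itself, not from a separate spec):

#include <string>
#include <utility>

// Illustration only: returns the (mean, variance) input indices the same
// way the kernel above picks them, keyed on the fused-variant name.
std::pair<int, int> MeanVarInputIndices(const std::string &kernel_name) {
  if (kernel_name.find("FusedBatchNormGradCPU") != std::string::npos) {
    return {4, 5};  // fused variant: mean at input 4, variance at input 5
  }
  return {3, 4};    // plain BatchNormGrad: mean at input 3, variance at input 4
}
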
@@ -64,11 +64,11 @@ class TransferSession : public lite::TrainSession {
   int CompileTransferGraph();
 
  protected:
-  lite::LiteSession *backbone_session_;
-  char *lite_model_;
+  lite::LiteSession *backbone_session_ = nullptr;
+  char *lite_model_ = nullptr;
   std::vector<mindspore::tensor::MSTensor *> combined_inputs_;
   std::vector<std::pair<mindspore::tensor::MSTensor *, mindspore::tensor::MSTensor *>> backbone_head_map_;
-  bool is_valid_;
+  bool is_valid_ = false;
 
  private:
   bool CompileFormatTransform(tensor::MSTensor *out, tensor::MSTensor *in, int *mask);

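The header change above gives the raw pointer members and the validity flag in-class default initializers, so they hold defined values even if a constructor path exits before assigning them. A generic sketch of the idiom with illustrative names (not the MindSpore classes):

#include <cstdio>

// Illustrative class, not the MindSpore TransferSession: default member
// initializers keep the pointer and flag well-defined on every construction
// path, so the destructor and validity check are always safe to run.
class Session {
 public:
  ~Session() { delete[] buffer_; }            // deleting nullptr is a no-op
  bool IsValid() const { return is_valid_; }

 private:
  char *buffer_ = nullptr;   // indeterminate without the "= nullptr"
  bool is_valid_ = false;    // reads before assignment are now well-defined
};

int main() {
  Session s;
  std::printf("valid: %d\n", s.IsValid());
  return 0;
}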