fix train issue

This commit is contained in:
zhengjun10 2021-03-19 15:41:01 +08:00
parent 83d4c8dbe3
commit dfd2f8f2c1
2 changed files with 12 additions and 1 deletions

View File

@ -81,7 +81,14 @@ int StridedSliceGradInferShape(const TensorC *const *inputs, size_t inputs_size,
ellipsis_mask_[i] = (bool)(param->ellipsisMask_) & (1 << i);
new_axis_mask_[i] = (bool)(param->newAxisMask_) & (1 << i);
}
param->num_axes_ = in_shape_size;
param->in_shape_length_ = in_shape_size;
for (int i = 0; i < ndim_; ++i) {
param->begins_[i] = begins_[i];
param->ends_[i] = ends_[i];
param->strides_[i] = strides_[i];
}
ShapeSet(param->in_shape_, &in_shape_size, input->shape_, input->shape_size_);
// ApplyNewAxisMask();
for (size_t i = 0; i < ndim_; i++) {
if (new_axis_mask_[i]) {

View File

@ -387,6 +387,10 @@ void TrainSession::CompileOptimizedKernels() {
}
int TrainSession::SetLearningRate(float learning_rate) {
if (learning_rate < 0.0f) {
MS_LOG(ERROR) << "learning rate should more than 0";
return RET_ERROR;
}
for (auto kernel : this->train_kernels_) {
if (IsOptimizer(kernel)) {
auto optimizer = reinterpret_cast<kernel::OptimizerKernel *>(kernel);