forked from mindspore-Ecosystem/mindspore
fix train issue
This commit is contained in:
parent
83d4c8dbe3
commit
dfd2f8f2c1
|
@ -81,7 +81,14 @@ int StridedSliceGradInferShape(const TensorC *const *inputs, size_t inputs_size,
|
|||
ellipsis_mask_[i] = (bool)(param->ellipsisMask_) & (1 << i);
|
||||
new_axis_mask_[i] = (bool)(param->newAxisMask_) & (1 << i);
|
||||
}
|
||||
|
||||
param->num_axes_ = in_shape_size;
|
||||
param->in_shape_length_ = in_shape_size;
|
||||
for (int i = 0; i < ndim_; ++i) {
|
||||
param->begins_[i] = begins_[i];
|
||||
param->ends_[i] = ends_[i];
|
||||
param->strides_[i] = strides_[i];
|
||||
}
|
||||
ShapeSet(param->in_shape_, &in_shape_size, input->shape_, input->shape_size_);
|
||||
// ApplyNewAxisMask();
|
||||
for (size_t i = 0; i < ndim_; i++) {
|
||||
if (new_axis_mask_[i]) {
|
||||
|
|
|
@ -387,6 +387,10 @@ void TrainSession::CompileOptimizedKernels() {
|
|||
}
|
||||
|
||||
int TrainSession::SetLearningRate(float learning_rate) {
|
||||
if (learning_rate < 0.0f) {
|
||||
MS_LOG(ERROR) << "learning rate should more than 0";
|
||||
return RET_ERROR;
|
||||
}
|
||||
for (auto kernel : this->train_kernels_) {
|
||||
if (IsOptimizer(kernel)) {
|
||||
auto optimizer = reinterpret_cast<kernel::OptimizerKernel *>(kernel);
|
||||
|
|
Loading…
Reference in New Issue