fix CPU BatchNorm bugs

zhaoting 2021-07-01 11:25:10 +08:00
parent 8195488630
commit 7637fcbf55
5 changed files with 14 additions and 11 deletions

@@ -13,8 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "backend/kernel_compiler/cpu/mkldnn/batch_norm_gard_cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/mkldnn/batch_norm_grad_cpu_kernel.h"
 #include <string>
 #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
 #include "runtime/device/cpu/cpu_device_address.h"
@@ -83,11 +82,13 @@ bool BatchNormGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &input
   }
   auto wksp_in = reinterpret_cast<float *>(workspace[0]->addr);
   auto scale_ret = memcpy_s(wksp_in, workspace[0]->size, inputs[2]->addr, inputs[2]->size);
+  if (scale_ret != 0) {
+    MS_LOG(EXCEPTION) << "Scale memcpy error!";
+  }
   auto max_size = workspace[0]->size - inputs[2]->size;
   auto bias_ret = memset_s(wksp_in + (inputs[2]->size / sizeof(float)), max_size, 0, max_size);
-  if (scale_ret != 0 && bias_ret != 0) {
-    MS_LOG(EXCEPTION) << "Memcpy_s error.";
-    return false;
+  if (bias_ret != 0) {
+    MS_LOG(EXCEPTION) << "Bias memset 0 error.";
   }
   SetArgumentHandle(DNNL_ARG_DIFF_DST, inputs[0]->addr);
@@ -101,11 +102,13 @@ bool BatchNormGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &input
   auto wksp_out = reinterpret_cast<float *>(workspace[1]->addr);
   auto diff_scale_ret = memcpy_s(outputs[1]->addr, outputs[1]->size, wksp_out, inputs[2]->size);
+  if (diff_scale_ret != 0) {
+    MS_LOG(EXCEPTION) << "Diff_scale memcpy to output[1] error.";
+  }
   auto diff_bias_ret =
     memcpy_s(outputs[2]->addr, outputs[2]->size, wksp_out + (outputs[1]->size / sizeof(float)), outputs[2]->size);
-  if (diff_scale_ret != 0 || diff_bias_ret != 0) {
-    MS_LOG(EXCEPTION) << "Memcpy_s error.";
-    return false;
+  if (diff_bias_ret != 0) {
+    MS_LOG(EXCEPTION) << "Diff_bias memcpy to output[2] error.";
   }
   return true;
 }
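
The pre-patch code tested the two status codes together after the fact (`scale_ret != 0 && bias_ret != 0`), so a failure in only one of the memcpy_s/memset_s calls went unreported and the kernel kept running on a partially initialized workspace; the patch checks each call right where it runs and gives it a specific message. Below is a minimal standalone sketch of that corrected pattern, not MindSpore code: CopyChecked and ZeroChecked are hypothetical stand-ins for the securec calls, and reporting is reduced to stderr.

#include <cstddef>
#include <cstdio>
#include <cstring>

// Hypothetical stand-ins for the securec memcpy_s/memset_s calls used by the
// kernel: each returns 0 on success and non-zero when the destination range
// is too small for the request.
static int CopyChecked(float *dst, std::size_t dst_bytes, const float *src, std::size_t src_bytes) {
  if (src_bytes > dst_bytes) return 1;
  std::memcpy(dst, src, src_bytes);
  return 0;
}

static int ZeroChecked(float *dst, std::size_t dst_bytes, std::size_t count_bytes) {
  if (count_bytes > dst_bytes) return 1;
  std::memset(dst, 0, count_bytes);
  return 0;
}

// Fixed pattern: validate each call's status immediately with a specific
// message. The pre-patch code tested `scale_ret != 0 && bias_ret != 0`, which
// only reported an error when *both* operations had failed.
bool FillGradWorkspace(float *wksp, std::size_t wksp_bytes, const float *scale, std::size_t scale_bytes) {
  int scale_ret = CopyChecked(wksp, wksp_bytes, scale, scale_bytes);
  if (scale_ret != 0) {
    std::fprintf(stderr, "Scale memcpy error!\n");
    return false;
  }
  std::size_t max_size = wksp_bytes - scale_bytes;
  int bias_ret = ZeroChecked(wksp + scale_bytes / sizeof(float), max_size, max_size);
  if (bias_ret != 0) {
    std::fprintf(stderr, "Bias memset 0 error.\n");
    return false;
  }
  return true;
}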

@@ -32,7 +32,7 @@ void Conv2dGradFilterCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   }
   std::vector<size_t> kernel_size({weight_shape[2], weight_shape[3]});
   size_t group = LongToSize(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, GROUP));
-  if (group != 1) {
+  if (group > 1) {
     if (src_shape[1] % group != 0) {
       MS_LOG(EXCEPTION) << "Conv2d channels should be divided by group!";
     }

@@ -34,7 +34,7 @@ void Conv2dGradInputCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   }
   std::vector<size_t> kernel_size({weight_shape[2], weight_shape[3]});
   size_t group = LongToSize(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, GROUP));
-  if (group != 1) {
+  if (group > 1) {
     if (src_shape[1] % group != 0) {
       MS_LOG(EXCEPTION) << "Conv2d channels should be divided by group!";
     }

@@ -42,7 +42,7 @@ void ConvCPUKernel::InitKernel(const CNodePtr &kernel_node) {
     kernel_size.emplace_back(weight_shape[i]);
   }
   size_t group = LongToSize(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, GROUP));
-  if (group != 1) {
+  if (group > 1) {
     if (src_shape[1] % group != 0) {
       MS_LOG(EXCEPTION) << "Conv channels should be divided by group!";
     }
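
All three convolution kernels receive the same guard change: the grouped path is entered only for group > 1, and the input channel count must divide evenly by the group count. A self-contained sketch of that validation follows; CheckGroupedConvChannels is a hypothetical helper for illustration, not the MindSpore API.

#include <cstddef>
#include <stdexcept>
#include <string>

// Hypothetical standalone version of the check applied in the patched conv
// kernels: only group counts above 1 trigger validation, and the input
// channel dimension must be evenly divisible by the group count.
void CheckGroupedConvChannels(std::size_t in_channels, std::size_t group) {
  if (group > 1) {
    if (in_channels % group != 0) {
      throw std::runtime_error("Conv channels should be divided by group! channels=" +
                               std::to_string(in_channels) + ", group=" + std::to_string(group));
    }
  }
}

// Example: a 64-channel input with group = 4 passes; group = 3 would throw.
// CheckGroupedConvChannels(64, 4);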