fix CPU BatchNorm bugs
This commit is contained in:
parent
8195488630
commit
7637fcbf55
|
@ -13,8 +13,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/kernel_compiler/cpu/mkldnn/batch_norm_gard_cpu_kernel.h"
|
||||
|
||||
#include "backend/kernel_compiler/cpu/mkldnn/batch_norm_grad_cpu_kernel.h"
|
||||
#include <string>
|
||||
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
|
@ -83,11 +82,13 @@ bool BatchNormGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &input
|
|||
}
|
||||
auto wksp_in = reinterpret_cast<float *>(workspace[0]->addr);
|
||||
auto scale_ret = memcpy_s(wksp_in, workspace[0]->size, inputs[2]->addr, inputs[2]->size);
|
||||
if (scale_ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "Scale memcpy error!";
|
||||
}
|
||||
auto max_size = workspace[0]->size - inputs[2]->size;
|
||||
auto bias_ret = memset_s(wksp_in + (inputs[2]->size / sizeof(float)), max_size, 0, max_size);
|
||||
if (scale_ret != 0 && bias_ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "Memcpy_s error.";
|
||||
return false;
|
||||
if (bias_ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "Bias memset 0 error.";
|
||||
}
|
||||
|
||||
SetArgumentHandle(DNNL_ARG_DIFF_DST, inputs[0]->addr);
|
||||
|
@ -101,11 +102,13 @@ bool BatchNormGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &input
|
|||
|
||||
auto wksp_out = reinterpret_cast<float *>(workspace[1]->addr);
|
||||
auto diff_scale_ret = memcpy_s(outputs[1]->addr, outputs[1]->size, wksp_out, inputs[2]->size);
|
||||
if (diff_scale_ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "Diff_scale memcpy to output[1] error.";
|
||||
}
|
||||
auto diff_bias_ret =
|
||||
memcpy_s(outputs[2]->addr, outputs[2]->size, wksp_out + (outputs[1]->size / sizeof(float)), outputs[2]->size);
|
||||
if (diff_scale_ret != 0 || diff_bias_ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "Memcpy_s error.";
|
||||
return false;
|
||||
if (diff_bias_ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "Diff_bias memcpy to to output[2] error.";
|
||||
}
|
||||
return true;
|
||||
}
|
|
@ -32,7 +32,7 @@ void Conv2dGradFilterCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
}
|
||||
std::vector<size_t> kernel_size({weight_shape[2], weight_shape[3]});
|
||||
size_t group = LongToSize(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, GROUP));
|
||||
if (group != 1) {
|
||||
if (group > 1) {
|
||||
if (src_shape[1] % group != 0) {
|
||||
MS_LOG(EXCEPTION) << "Conv2d channels should be divided by group!";
|
||||
}
|
||||
|
|
|
@ -34,7 +34,7 @@ void Conv2dGradInputCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
}
|
||||
std::vector<size_t> kernel_size({weight_shape[2], weight_shape[3]});
|
||||
size_t group = LongToSize(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, GROUP));
|
||||
if (group != 1) {
|
||||
if (group > 1) {
|
||||
if (src_shape[1] % group != 0) {
|
||||
MS_LOG(EXCEPTION) << "Conv2d channels should be divided by group!";
|
||||
}
|
||||
|
|
|
@ -42,7 +42,7 @@ void ConvCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
kernel_size.emplace_back(weight_shape[i]);
|
||||
}
|
||||
size_t group = LongToSize(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, GROUP));
|
||||
if (group != 1) {
|
||||
if (group > 1) {
|
||||
if (src_shape[1] % group != 0) {
|
||||
MS_LOG(EXCEPTION) << "Conv channels should be divided by group!";
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue