add memcpy_s check

This commit is contained in:
zhaoting 2020-09-17 18:03:46 +08:00
parent defd74e261
commit 2d315b32e6
2 changed files with 14 additions and 5 deletions

View File

@ -53,7 +53,11 @@ bool AssignAddCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
SetArgumentHandle(DNNL_ARG_SRC_1, inputs[1]->addr);
SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
ExecutePrimitive();
memcpy_s(inputs[0]->addr, inputs[0]->size, outputs[0]->addr, outputs[0]->size);
auto ret = memcpy_s(inputs[0]->addr, inputs[0]->size, outputs[0]->addr, outputs[0]->size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "Memcpy_s error, errorno " << ret;
return false;
}
return true;
}
} // namespace kernel

View File

@ -40,7 +40,7 @@ void FusedBatchNormCPUKernel::InitKernel(const CNodePtr &kernel_node) {
}
std::vector<size_t> x_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
if (x_shape.size() != 4) {
MS_LOG(EXCEPTION) << "fused batchnorm only support nchw input!";
MS_LOG(EXCEPTION) << "Fused batchnorm only support nchw input!";
}
batch_size = x_shape[0];
channel = x_shape[1];
@ -71,11 +71,16 @@ bool FusedBatchNormCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inpu
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs) {
if (inputs.size() < 5 || outputs.empty()) {
MS_LOG(EXCEPTION) << "error input output size!";
MS_LOG(EXCEPTION) << "Error input output size!";
}
auto wksp = reinterpret_cast<float *>(workspace[0]->addr);
memcpy_s(wksp, workspace[0]->size, inputs[1]->addr, inputs[1]->size);
memcpy_s(wksp + (inputs[1]->size / sizeof(float)), inputs[2]->size, inputs[2]->addr, inputs[2]->size);
auto scale_ret = memcpy_s(wksp, workspace[0]->size, inputs[1]->addr, inputs[1]->size);
auto max_size = workspace[0]->size - inputs[1]->size;
auto bias_ret = memcpy_s(wksp + (inputs[1]->size / sizeof(float)), max_size, inputs[2]->addr, inputs[2]->size);
if (scale_ret != 0 || bias_ret != 0) {
MS_LOG(EXCEPTION) << "Memcpy_s error.";
return false;
}
if (is_train) {
SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr);
SetArgumentHandle(DNNL_ARG_MEAN, outputs[3]->addr);