forked from mindspore-Ecosystem/mindspore
add memcpy_s check
This commit is contained in:
parent
defd74e261
commit
2d315b32e6
|
@ -53,7 +53,11 @@ bool AssignAddCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|||
SetArgumentHandle(DNNL_ARG_SRC_1, inputs[1]->addr);
|
||||
SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
|
||||
ExecutePrimitive();
|
||||
memcpy_s(inputs[0]->addr, inputs[0]->size, outputs[0]->addr, outputs[0]->size);
|
||||
auto ret = memcpy_s(inputs[0]->addr, inputs[0]->size, outputs[0]->addr, outputs[0]->size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "Memcpy_s error, errorno " << ret;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace kernel
|
||||
|
|
|
@ -40,7 +40,7 @@ void FusedBatchNormCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
}
|
||||
std::vector<size_t> x_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
if (x_shape.size() != 4) {
|
||||
MS_LOG(EXCEPTION) << "fused batchnorm only support nchw input!";
|
||||
MS_LOG(EXCEPTION) << "Fused batchnorm only support nchw input!";
|
||||
}
|
||||
batch_size = x_shape[0];
|
||||
channel = x_shape[1];
|
||||
|
@ -71,11 +71,16 @@ bool FusedBatchNormCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inpu
|
|||
const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (inputs.size() < 5 || outputs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "error input output size!";
|
||||
MS_LOG(EXCEPTION) << "Error input output size!";
|
||||
}
|
||||
auto wksp = reinterpret_cast<float *>(workspace[0]->addr);
|
||||
memcpy_s(wksp, workspace[0]->size, inputs[1]->addr, inputs[1]->size);
|
||||
memcpy_s(wksp + (inputs[1]->size / sizeof(float)), inputs[2]->size, inputs[2]->addr, inputs[2]->size);
|
||||
auto scale_ret = memcpy_s(wksp, workspace[0]->size, inputs[1]->addr, inputs[1]->size);
|
||||
auto max_size = workspace[0]->size - inputs[1]->size;
|
||||
auto bias_ret = memcpy_s(wksp + (inputs[1]->size / sizeof(float)), max_size, inputs[2]->addr, inputs[2]->size);
|
||||
if (scale_ret != 0 || bias_ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "Memcpy_s error.";
|
||||
return false;
|
||||
}
|
||||
if (is_train) {
|
||||
SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr);
|
||||
SetArgumentHandle(DNNL_ARG_MEAN, outputs[3]->addr);
|
||||
|
|
Loading…
Reference in New Issue