fix GatherV2 index out of bounds

This commit is contained in:
baihuawei 2020-07-06 09:39:30 +08:00
parent 5359d63eb0
commit b1d6ef0e88
1 changed files with 25 additions and 21 deletions

View File

@ -21,17 +21,14 @@ namespace mindspore {
namespace kernel { namespace kernel {
void GatherV2CPUKernel::InitKernel(const CNodePtr &kernel_node) { void GatherV2CPUKernel::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node); CheckParam(kernel_node);
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
indices_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); indices_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
axis_ = AnfAlgo::GetNodeAttr<int>(kernel_node, AXIS); axis_ = AnfAlgo::GetNodeAttr<int>(kernel_node, AXIS);
if (axis_ < 0) { if (axis_ < 0) {
axis_ = axis_ + SizeToInt(input_shape_.size()); axis_ = axis_ + SizeToInt(input_shape_.size());
} }
axis_ += 4 - input_shape_.size(); axis_ += 4 - input_shape_.size();
CPUKernelUtils::ExpandDimsTo4(&input_shape_); CPUKernelUtils::ExpandDimsTo4(&input_shape_);
CPUKernelUtils::ExpandDimsTo4(&output_shape_); CPUKernelUtils::ExpandDimsTo4(&output_shape_);
} }
@ -44,7 +41,6 @@ bool GatherV2CPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
size_t dim0 = input_shape_[0]; size_t dim0 = input_shape_[0];
size_t dim1 = input_shape_[1]; size_t dim1 = input_shape_[1];
size_t dim2 = input_shape_[2]; size_t dim2 = input_shape_[2];
if (axis_ == 3) { if (axis_ == 3) {
for (size_t i = 0; i < dim0; ++i) { for (size_t i = 0; i < dim0; ++i) {
for (size_t j = 0; j < dim1; ++j) { for (size_t j = 0; j < dim1; ++j) {
@ -66,7 +62,6 @@ bool GatherV2CPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
} else if (axis_ == 0) { } else if (axis_ == 0) {
CopyDataToOutput(inputs, 0, 0, 0, &output_addr, &buff_size); CopyDataToOutput(inputs, 0, 0, 0, &output_addr, &buff_size);
} }
return true; return true;
} }
@ -75,34 +70,43 @@ void GatherV2CPUKernel::CopyDataToOutput(const std::vector<kernel::AddressPtr> &
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr); auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
auto indices_addr = reinterpret_cast<int *>(inputs[1]->addr); auto indices_addr = reinterpret_cast<int *>(inputs[1]->addr);
size_t elem_num = inputs[1]->size / 4; size_t elem_num = inputs[1]->size / 4;
size_t num = CPUKernelUtils::GetElementNumOnAxis(input_shape_, axis_);
for (size_t i = 0; i < elem_num; ++i) { for (size_t i = 0; i < elem_num; ++i) {
size_t index = IntToSize(indices_addr[i]); if (indices_addr[i] < 0) {
size_t pos = 0; MS_LOG(EXCEPTION) << "The indices value is less than 0.";
if (axis_ == 3) {
pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, dim1, dim2, index);
} else if (axis_ == 2) {
pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, dim1, index, 0);
} else if (axis_ == 1) {
pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, index, 0, 0);
} else if (axis_ == 0) {
pos = CPUKernelUtils::CalcOffset(input_shape_, index, 0, 0, 0);
} }
size_t num = CPUKernelUtils::GetElementNumOnAxis(input_shape_, axis_); size_t index = IntToSize(indices_addr[i]);
auto ret = memcpy_s(*output_addr, *buff_size, input_addr + pos, num * sizeof(float)); if (index >= input_shape_[IntToSize(axis_)]) {
if (ret != EOK) { auto ret = memset_s(*output_addr, *buff_size, 0., num * sizeof(float));
MS_LOG(EXCEPTION) << "memcpy failed."; if (ret != EOK) {
MS_LOG(EXCEPTION) << "memset failed.";
}
} else {
size_t pos = 0;
if (axis_ == 3) {
pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, dim1, dim2, index);
} else if (axis_ == 2) {
pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, dim1, index, 0);
} else if (axis_ == 1) {
pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, index, 0, 0);
} else if (axis_ == 0) {
pos = CPUKernelUtils::CalcOffset(input_shape_, index, 0, 0, 0);
}
auto ret = memcpy_s(*output_addr, *buff_size, input_addr + pos, num * sizeof(float));
if (ret != EOK) {
MS_LOG(EXCEPTION) << "memcpy failed.";
}
} }
*output_addr += num; *output_addr += num;
*buff_size -= num * sizeof(float); *buff_size -= num * sizeof(float);
} }
} } // namespace kernel
void GatherV2CPUKernel::CheckParam(const CNodePtr &kernel_node) { void GatherV2CPUKernel::CheckParam(const CNodePtr &kernel_node) {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (input_shape.size() > 4) { if (input_shape.size() > 4) {
MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but GatherV2CPUKernel olny support 4d or lower."; MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but GatherV2CPUKernel olny support 4d or lower.";
} }
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) { if (input_num != 2) {
MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherV2CPUKernel needs 2."; MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherV2CPUKernel needs 2.";