fix cpu kernel:ScatterNdUpdate doesn't set output

This commit is contained in:
huanghui 2020-09-15 14:12:24 +08:00
parent aa16811ba5
commit d6944a70ca
5 changed files with 13 additions and 5 deletions

View File

@ -63,11 +63,11 @@ void ScatterNdUpdateCPUKernel::InitKernel(const CNodePtr &kernel_node) {
bool ScatterNdUpdateCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> & /*outputs*/) {
const std::vector<kernel::AddressPtr> &outputs) {
if (dtype_ == kNumberTypeFloat16) {
LaunchKernel<float16>(inputs);
LaunchKernel<float16>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat32) {
LaunchKernel<float>(inputs);
LaunchKernel<float>(inputs, outputs);
} else {
MS_LOG(ERROR) << "Only support float16, float32";
return false;
@ -76,7 +76,8 @@ bool ScatterNdUpdateCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
}
template <typename T>
void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs) {
void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto x = reinterpret_cast<T *>(inputs[0]->addr);
auto indices = reinterpret_cast<int *>(inputs[1]->addr);
auto updates = reinterpret_cast<T *>(inputs[2]->addr);
@ -100,6 +101,10 @@ void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector<AddressPtr> &input
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
}
}
auto ret = memcpy_s(outputs[0]->addr, outputs[0]->size, x, mem_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
}
}
void ScatterNdUpdateCPUKernel::Check(const CNodePtr &kernel_node) {

View File

@ -35,7 +35,7 @@ class ScatterNdUpdateCPUKernel : public CPUKernel {
const std::vector<AddressPtr> &outputs) override;
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs);
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
private:
void Check(const CNodePtr &kernel_node);

View File

@ -15,6 +15,7 @@
*/
#include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h"
#include <vector>
#include <algorithm>
#include "backend/optimizer/ascend/ir_fusion/input_to_output_registry.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/oplib/oplib.h"

View File

@ -17,6 +17,7 @@
#include "backend/optimizer/pass/add_atomic_clean.h"
#include <memory>
#include <vector>
#include <functional>
#include "base/core_ops.h"
#include "utils/utils.h"
#include "utils/log_adapter.h"

View File

@ -16,6 +16,7 @@
#include "backend/optimizer/pass/const_to_attr_strided_slice_grad.h"
#include <memory>
#include <vector>
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/ms_context.h"