!6247 Fix cpu ScatterNdUpdate doesn't update output
Merge pull request !6247 from huanghui/clear-warning
This commit is contained in:
commit
ff42cd87b2
|
@ -63,11 +63,11 @@ void ScatterNdUpdateCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
|
||||
bool ScatterNdUpdateCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> & /*workspace*/,
|
||||
const std::vector<kernel::AddressPtr> & /*outputs*/) {
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (dtype_ == kNumberTypeFloat16) {
|
||||
LaunchKernel<float16>(inputs);
|
||||
LaunchKernel<float16>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeFloat32) {
|
||||
LaunchKernel<float>(inputs);
|
||||
LaunchKernel<float>(inputs, outputs);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Only support float16, float32";
|
||||
return false;
|
||||
|
@ -76,7 +76,8 @@ bool ScatterNdUpdateCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs) {
|
||||
void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
auto x = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto indices = reinterpret_cast<int *>(inputs[1]->addr);
|
||||
auto updates = reinterpret_cast<T *>(inputs[2]->addr);
|
||||
|
@ -100,6 +101,10 @@ void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector<AddressPtr> &input
|
|||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
|
||||
}
|
||||
}
|
||||
auto ret = memcpy_s(outputs[0]->addr, outputs[0]->size, x, mem_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
|
||||
}
|
||||
}
|
||||
|
||||
void ScatterNdUpdateCPUKernel::Check(const CNodePtr &kernel_node) {
|
||||
|
|
|
@ -35,7 +35,7 @@ class ScatterNdUpdateCPUKernel : public CPUKernel {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
template <typename T>
|
||||
void LaunchKernel(const std::vector<AddressPtr> &inputs);
|
||||
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
|
||||
private:
|
||||
void Check(const CNodePtr &kernel_node);
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
#include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h"
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "backend/optimizer/ascend/ir_fusion/input_to_output_registry.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/kernel_compiler/oplib/oplib.h"
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "backend/optimizer/pass/add_atomic_clean.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
#include "base/core_ops.h"
|
||||
#include "utils/utils.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include "backend/optimizer/pass/const_to_attr_strided_slice_grad.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
|
Loading…
Reference in New Issue