!46564 fix fp16 precision problem of Scatter operators on CPU.

Merge pull request !46564 from yangshuo/fix_scattermul
This commit is contained in:
i-robot 2022-12-09 01:40:46 +00:00 committed by Gitee
commit 2c3253df58
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 10 additions and 10 deletions

View File

@ -68,14 +68,14 @@ template <typename T, typename S>
bool ScatterArithmeticCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
static const mindspore::HashMap<std::string, std::function<T(const T &a, const T &b)>> scatter_arithmetic_func_map{
{prim::kPrimScatterMul->name(), [](const T &a, const T &b) { return a * b; }},
{prim::kPrimScatterDiv->name(), [](const T &a, const T &b) { return a / b; }},
{prim::kPrimScatterAdd->name(), [](const T &a, const T &b) { return a + b; }},
{prim::kPrimScatterSub->name(), [](const T &a, const T &b) { return a - b; }},
{prim::kPrimScatterMax->name(), [](const T &a, const T &b) { return a > b ? a : b; }},
{prim::kPrimScatterMin->name(), [](const T &a, const T &b) { return a > b ? b : a; }},
{prim::kPrimScatterUpdate->name(), [](const T &a, const T &b) { return b; }},
static const mindspore::HashMap<std::string, std::function<void(T & a, const T &b)>> scatter_arithmetic_func_map{
{prim::kPrimScatterMul->name(), [](T &a, const T &b) { a *= b; }},
{prim::kPrimScatterDiv->name(), [](T &a, const T &b) { a /= b; }},
{prim::kPrimScatterAdd->name(), [](T &a, const T &b) { a += b; }},
{prim::kPrimScatterSub->name(), [](T &a, const T &b) { a -= b; }},
{prim::kPrimScatterMax->name(), [](T &a, const T &b) { a = a > b ? a : b; }},
{prim::kPrimScatterMin->name(), [](T &a, const T &b) { a = a > b ? b : a; }},
{prim::kPrimScatterUpdate->name(), [](T &a, const T &b) { a = b; }},
};
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kScatterArithmeticInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kScatterArithmeticOutputsNum, kernel_name_);
@ -100,7 +100,7 @@ bool ScatterArithmeticCpuKernelMod::LaunchKernel(const std::vector<kernel::Addre
if (std::equal_to<T>()(updates[base_index_updates + j], static_cast<T>(0))) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', updates must not contain 0";
}
input[base_index_input + j] = func_iter->second(input[base_index_input + j], updates[base_index_updates + j]);
func_iter->second(input[base_index_input + j], updates[base_index_updates + j]);
}
}
} else {
@ -112,7 +112,7 @@ bool ScatterArithmeticCpuKernelMod::LaunchKernel(const std::vector<kernel::Addre
<< "), but got '" << indices[i] << "' in indices.";
}
for (size_t j = 0; j < inner_size_; j++) {
input[base_index_input + j] = func_iter->second(input[base_index_input + j], updates[base_index_updates + j]);
func_iter->second(input[base_index_input + j], updates[base_index_updates + j]);
}
}
}