!37829 TruncateMod and TruncateDiv Bugfixes

Merge pull request !37829 from jaspreetsinghsambee/truncate_ops_dev
This commit is contained in:
i-robot 2022-07-19 13:07:59 +00:00 committed by Gitee
commit 452d612e69
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 138 additions and 7 deletions

View File

@ -16,6 +16,7 @@
#include "plugin/device/cpu/kernel/truncate_div_cpu_kernel.h"
#include <limits>
#include <algorithm>
#include <functional>
#include <utility>
@ -74,7 +75,22 @@ bool TruncateDivCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr>
if (input_shape_1_ == input_shape_2_) {
auto task = [output_addr, input_addr_a, input_addr_b](size_t start, size_t end) {
for (size_t i = start; i < end; ++i) {
output_addr[i] = static_cast<T>(input_addr_a[i] / input_addr_b[i]);
auto dividend = input_addr_a[i];
auto divisor = input_addr_b[i];
auto zero = (T)0;
if (divisor == zero) {
if (dividend == zero) {
output_addr[i] = std::numeric_limits<T>::quiet_NaN();
continue;
}
if (std::numeric_limits<T>::has_infinity) {
output_addr[i] = dividend > zero ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
} else {
output_addr[i] = dividend > zero ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
}
continue;
}
output_addr[i] = static_cast<T>(dividend / divisor);
}
};
ParallelLaunchAutoSearch(task, output_size, this, &parallel_search_info_);
@ -84,7 +100,22 @@ bool TruncateDivCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr>
auto iter = base_iter;
iter.SetPos(start);
for (size_t i = start; i < end; ++i) {
output_addr[i] = static_cast<T>(input_addr_a[iter.GetInputPosA()] / input_addr_b[iter.GetInputPosB()]);
auto dividend = input_addr_a[iter.GetInputPosA()];
auto divisor = input_addr_b[iter.GetInputPosB()];
auto zero = (T)0;
if (divisor == zero) {
if (dividend == zero) {
output_addr[i] = std::numeric_limits<T>::quiet_NaN();
continue;
}
if (std::numeric_limits<T>::has_infinity) {
output_addr[i] = dividend > zero ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
} else {
output_addr[i] = dividend > zero ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
}
continue;
}
output_addr[i] = static_cast<T>(dividend / divisor);
iter.GenNextPos();
}
};

View File

@ -16,6 +16,7 @@
#include "plugin/device/cpu/kernel/truncate_mod_cpu_kernel.h"
#include <limits>
#include <algorithm>
#include <functional>
#include <utility>
@ -31,6 +32,25 @@ constexpr size_t kTruncateModInputsNum = 2;
constexpr size_t kTruncateModOutputsNum = 1;
} // namespace
template <typename T>
T GetTruncModDivZeroVal(const T &v) {
auto zero = static_cast<T>(0.0);
if (std::numeric_limits<T>::has_infinity) {
return v > zero ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
} else {
return v > zero ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
}
}
float16 GetTruncModDivZeroVal(const float16 &v) {
auto zero = static_cast<float16>(0.0);
if (std::numeric_limits<float16>::has_infinity) {
return v > zero ? std::numeric_limits<float16>::infinity() : -std::numeric_limits<float16>::infinity();
} else {
return v > zero ? std::numeric_limits<float16>::max() : std::numeric_limits<float16>::min();
}
}
bool TruncateModCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
@ -72,7 +92,18 @@ bool TruncateModCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr>
if (input_shape_1_ == input_shape_2_) {
auto task = [output_addr, input_addr_a, input_addr_b](size_t start, size_t end) {
for (size_t i = start; i < end; ++i) {
output_addr[i] = static_cast<T>(static_cast<int>(input_addr_a[i]) % static_cast<int>(input_addr_b[i]));
auto dividend = input_addr_a[i];
auto divisor = input_addr_b[i];
auto zero = (T)0;
if (divisor == zero) {
if (dividend == zero) {
output_addr[i] = std::numeric_limits<T>::quiet_NaN();
continue;
}
output_addr[i] = GetTruncModDivZeroVal(dividend);
continue;
}
output_addr[i] = static_cast<T>(dividend - static_cast<int>(dividend / divisor) * divisor);
}
};
ParallelLaunchAutoSearch(task, output_size, this, &parallel_search_info_);
@ -82,8 +113,75 @@ bool TruncateModCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr>
auto iter = base_iter;
iter.SetPos(start);
for (size_t i = start; i < end; ++i) {
output_addr[i] = static_cast<T>(static_cast<int>(input_addr_a[iter.GetInputPosA()]) %
static_cast<int>(input_addr_b[iter.GetInputPosB()]));
auto dividend = input_addr_a[iter.GetInputPosA()];
auto divisor = input_addr_b[iter.GetInputPosB()];
auto zero = (T)0;
if (divisor == zero) {
if (dividend == zero) {
output_addr[i] = std::numeric_limits<T>::quiet_NaN();
continue;
}
output_addr[i] = GetTruncModDivZeroVal(dividend);
continue;
}
output_addr[i] = static_cast<T>(dividend - static_cast<int>(dividend / divisor) * divisor);
iter.GenNextPos();
}
};
ParallelLaunchAutoSearch(task, output_size, this, &parallel_search_info_);
}
return true;
}
bool TruncateModCpuKernelMod::LaunchKernelHalf(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kTruncateModInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kTruncateModOutputsNum, kernel_name_);
auto *input_addr_a = reinterpret_cast<float16 *>(inputs[kZero]->addr);
auto *input_addr_b = reinterpret_cast<float16 *>(inputs[kOne]->addr);
auto *output_addr = reinterpret_cast<float16 *>(outputs[kZero]->addr);
size_t output_size = outputs[0]->size / sizeof(float16);
if (input_shape_1_ == input_shape_2_) {
auto task = [output_addr, input_addr_a, input_addr_b](size_t start, size_t end) {
for (size_t i = start; i < end; ++i) {
auto dividend = input_addr_a[i];
auto divisor = input_addr_b[i];
auto zero = (float16)0;
if (divisor == zero) {
if (dividend == zero) {
output_addr[i] = std::numeric_limits<float16>::quiet_NaN();
continue;
}
output_addr[i] = GetTruncModDivZeroVal(dividend);
continue;
}
output_addr[i] = static_cast<float16>(
static_cast<float>(dividend) -
static_cast<int>(static_cast<float>(dividend) / static_cast<float>(divisor)) * static_cast<float>(divisor));
}
};
ParallelLaunchAutoSearch(task, output_size, this, &parallel_search_info_);
} else { // For Broadcast
BroadcastIterator base_iter(input_shape_1_, input_shape_2_, output_shape_);
auto task = [&base_iter, output_addr, input_addr_a, input_addr_b](size_t start, size_t end) {
auto iter = base_iter;
iter.SetPos(start);
for (size_t i = start; i < end; ++i) {
auto dividend = input_addr_a[iter.GetInputPosA()];
auto divisor = input_addr_b[iter.GetInputPosB()];
auto zero = (float16)0;
if (divisor == zero) {
if (dividend == zero) {
output_addr[i] = std::numeric_limits<float16>::quiet_NaN();
continue;
}
output_addr[i] = GetTruncModDivZeroVal(dividend);
continue;
}
output_addr[i] = static_cast<float16>(
static_cast<float>(dividend) -
static_cast<int>(static_cast<float>(dividend) / static_cast<float>(divisor)) * static_cast<float>(divisor));
iter.GenNextPos();
}
};
@ -114,7 +212,7 @@ std::vector<std::pair<KernelAttr, TruncateModCpuKernelMod::TruncateModFunc>> Tru
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&TruncateModCpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&TruncateModCpuKernelMod::LaunchKernel<float16>}};
&TruncateModCpuKernelMod::LaunchKernelHalf}};
std::vector<KernelAttr> TruncateModCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;

View File

@ -49,6 +49,8 @@ class TruncateModCpuKernelMod : public NativeCpuKernelMod {
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
bool LaunchKernelHalf(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
using TruncateModFunc =
std::function<bool(TruncateModCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;

View File

@ -57,7 +57,7 @@ TypePtr TruncateModInferType(const PrimitivePtr &prim, const std::vector<Abstrac
MS_EXCEPTION(TypeError) << "For '" << prim_name << "', input must be a tensor, but got: " << z_type->ToString()
<< ".";
}
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kInt8, kInt32, kUInt8};
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64, kInt8, kInt32, kUInt8};
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", z_type, valid_types, prim_name);
return z_type;
}