forked from mindspore-Ecosystem/mindspore
!37829 TruncateMod and TruncateDiv Bugfixes
Merge pull request !37829 from jaspreetsinghsambee/truncate_ops_dev
This commit is contained in:
commit
452d612e69
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "plugin/device/cpu/kernel/truncate_div_cpu_kernel.h"
|
||||
|
||||
#include <limits>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
|
@ -74,7 +75,22 @@ bool TruncateDivCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr>
|
|||
if (input_shape_1_ == input_shape_2_) {
|
||||
auto task = [output_addr, input_addr_a, input_addr_b](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; ++i) {
|
||||
output_addr[i] = static_cast<T>(input_addr_a[i] / input_addr_b[i]);
|
||||
auto dividend = input_addr_a[i];
|
||||
auto divisor = input_addr_b[i];
|
||||
auto zero = (T)0;
|
||||
if (divisor == zero) {
|
||||
if (dividend == zero) {
|
||||
output_addr[i] = std::numeric_limits<T>::quiet_NaN();
|
||||
continue;
|
||||
}
|
||||
if (std::numeric_limits<T>::has_infinity) {
|
||||
output_addr[i] = dividend > zero ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
|
||||
} else {
|
||||
output_addr[i] = dividend > zero ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
output_addr[i] = static_cast<T>(dividend / divisor);
|
||||
}
|
||||
};
|
||||
ParallelLaunchAutoSearch(task, output_size, this, &parallel_search_info_);
|
||||
|
@ -84,7 +100,22 @@ bool TruncateDivCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr>
|
|||
auto iter = base_iter;
|
||||
iter.SetPos(start);
|
||||
for (size_t i = start; i < end; ++i) {
|
||||
output_addr[i] = static_cast<T>(input_addr_a[iter.GetInputPosA()] / input_addr_b[iter.GetInputPosB()]);
|
||||
auto dividend = input_addr_a[iter.GetInputPosA()];
|
||||
auto divisor = input_addr_b[iter.GetInputPosB()];
|
||||
auto zero = (T)0;
|
||||
if (divisor == zero) {
|
||||
if (dividend == zero) {
|
||||
output_addr[i] = std::numeric_limits<T>::quiet_NaN();
|
||||
continue;
|
||||
}
|
||||
if (std::numeric_limits<T>::has_infinity) {
|
||||
output_addr[i] = dividend > zero ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
|
||||
} else {
|
||||
output_addr[i] = dividend > zero ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
output_addr[i] = static_cast<T>(dividend / divisor);
|
||||
iter.GenNextPos();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "plugin/device/cpu/kernel/truncate_mod_cpu_kernel.h"
|
||||
|
||||
#include <limits>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
|
@ -31,6 +32,25 @@ constexpr size_t kTruncateModInputsNum = 2;
|
|||
constexpr size_t kTruncateModOutputsNum = 1;
|
||||
} // namespace
|
||||
|
||||
// Picks the value stored when a dividend `v` is divided by zero.
// Types that can represent infinity get +infinity for v > 0 and -infinity otherwise;
// types without infinity fall back to numeric_limits max() / min() using the same sign test.
// Note that any v that is not strictly greater than zero takes the "negative" result.
template <typename T>
T GetTruncModDivZeroVal(const T &v) {
  using Limits = std::numeric_limits<T>;
  const bool is_positive = v > static_cast<T>(0.0);
  if (!Limits::has_infinity) {
    // No infinity available for T: use the extreme representable values instead.
    return is_positive ? Limits::max() : Limits::min();
  }
  return is_positive ? Limits::infinity() : -Limits::infinity();
}
|
||||
|
||||
// float16 overload of the divide-by-zero result selector.
// Yields +infinity for v > 0 and -infinity otherwise when numeric_limits<float16> reports
// infinity support; otherwise yields numeric_limits<float16> max() / min() by the same test.
float16 GetTruncModDivZeroVal(const float16 &v) {
  using HalfLimits = std::numeric_limits<float16>;
  const bool is_positive = v > static_cast<float16>(0.0);
  if (!HalfLimits::has_infinity) {
    return is_positive ? HalfLimits::max() : HalfLimits::min();
  }
  return is_positive ? HalfLimits::infinity() : -HalfLimits::infinity();
}
|
||||
|
||||
bool TruncateModCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
MS_EXCEPTION_IF_NULL(base_operator);
|
||||
|
@ -72,7 +92,18 @@ bool TruncateModCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr>
|
|||
if (input_shape_1_ == input_shape_2_) {
|
||||
auto task = [output_addr, input_addr_a, input_addr_b](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; ++i) {
|
||||
output_addr[i] = static_cast<T>(static_cast<int>(input_addr_a[i]) % static_cast<int>(input_addr_b[i]));
|
||||
auto dividend = input_addr_a[i];
|
||||
auto divisor = input_addr_b[i];
|
||||
auto zero = (T)0;
|
||||
if (divisor == zero) {
|
||||
if (dividend == zero) {
|
||||
output_addr[i] = std::numeric_limits<T>::quiet_NaN();
|
||||
continue;
|
||||
}
|
||||
output_addr[i] = GetTruncModDivZeroVal(dividend);
|
||||
continue;
|
||||
}
|
||||
output_addr[i] = static_cast<T>(dividend - static_cast<int>(dividend / divisor) * divisor);
|
||||
}
|
||||
};
|
||||
ParallelLaunchAutoSearch(task, output_size, this, &parallel_search_info_);
|
||||
|
@ -82,8 +113,75 @@ bool TruncateModCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr>
|
|||
auto iter = base_iter;
|
||||
iter.SetPos(start);
|
||||
for (size_t i = start; i < end; ++i) {
|
||||
output_addr[i] = static_cast<T>(static_cast<int>(input_addr_a[iter.GetInputPosA()]) %
|
||||
static_cast<int>(input_addr_b[iter.GetInputPosB()]));
|
||||
auto dividend = input_addr_a[iter.GetInputPosA()];
|
||||
auto divisor = input_addr_b[iter.GetInputPosB()];
|
||||
auto zero = (T)0;
|
||||
if (divisor == zero) {
|
||||
if (dividend == zero) {
|
||||
output_addr[i] = std::numeric_limits<T>::quiet_NaN();
|
||||
continue;
|
||||
}
|
||||
output_addr[i] = GetTruncModDivZeroVal(dividend);
|
||||
continue;
|
||||
}
|
||||
output_addr[i] = static_cast<T>(dividend - static_cast<int>(dividend / divisor) * divisor);
|
||||
iter.GenNextPos();
|
||||
}
|
||||
};
|
||||
ParallelLaunchAutoSearch(task, output_size, this, &parallel_search_info_);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TruncateModCpuKernelMod::LaunchKernelHalf(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kTruncateModInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kTruncateModOutputsNum, kernel_name_);
|
||||
auto *input_addr_a = reinterpret_cast<float16 *>(inputs[kZero]->addr);
|
||||
auto *input_addr_b = reinterpret_cast<float16 *>(inputs[kOne]->addr);
|
||||
auto *output_addr = reinterpret_cast<float16 *>(outputs[kZero]->addr);
|
||||
size_t output_size = outputs[0]->size / sizeof(float16);
|
||||
if (input_shape_1_ == input_shape_2_) {
|
||||
auto task = [output_addr, input_addr_a, input_addr_b](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; ++i) {
|
||||
auto dividend = input_addr_a[i];
|
||||
auto divisor = input_addr_b[i];
|
||||
auto zero = (float16)0;
|
||||
if (divisor == zero) {
|
||||
if (dividend == zero) {
|
||||
output_addr[i] = std::numeric_limits<float16>::quiet_NaN();
|
||||
continue;
|
||||
}
|
||||
output_addr[i] = GetTruncModDivZeroVal(dividend);
|
||||
continue;
|
||||
}
|
||||
output_addr[i] = static_cast<float16>(
|
||||
static_cast<float>(dividend) -
|
||||
static_cast<int>(static_cast<float>(dividend) / static_cast<float>(divisor)) * static_cast<float>(divisor));
|
||||
}
|
||||
};
|
||||
ParallelLaunchAutoSearch(task, output_size, this, &parallel_search_info_);
|
||||
} else { // For Broadcast
|
||||
BroadcastIterator base_iter(input_shape_1_, input_shape_2_, output_shape_);
|
||||
auto task = [&base_iter, output_addr, input_addr_a, input_addr_b](size_t start, size_t end) {
|
||||
auto iter = base_iter;
|
||||
iter.SetPos(start);
|
||||
for (size_t i = start; i < end; ++i) {
|
||||
auto dividend = input_addr_a[iter.GetInputPosA()];
|
||||
auto divisor = input_addr_b[iter.GetInputPosB()];
|
||||
auto zero = (float16)0;
|
||||
if (divisor == zero) {
|
||||
if (dividend == zero) {
|
||||
output_addr[i] = std::numeric_limits<float16>::quiet_NaN();
|
||||
continue;
|
||||
}
|
||||
output_addr[i] = GetTruncModDivZeroVal(dividend);
|
||||
continue;
|
||||
}
|
||||
output_addr[i] = static_cast<float16>(
|
||||
static_cast<float>(dividend) -
|
||||
static_cast<int>(static_cast<float>(dividend) / static_cast<float>(divisor)) * static_cast<float>(divisor));
|
||||
iter.GenNextPos();
|
||||
}
|
||||
};
|
||||
|
@ -114,7 +212,7 @@ std::vector<std::pair<KernelAttr, TruncateModCpuKernelMod::TruncateModFunc>> Tru
|
|||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
&TruncateModCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
&TruncateModCpuKernelMod::LaunchKernel<float16>}};
|
||||
&TruncateModCpuKernelMod::LaunchKernelHalf}};
|
||||
|
||||
std::vector<KernelAttr> TruncateModCpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
|
|
|
@ -49,6 +49,8 @@ class TruncateModCpuKernelMod : public NativeCpuKernelMod {
|
|||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
bool LaunchKernelHalf(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
using TruncateModFunc =
|
||||
std::function<bool(TruncateModCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
|
|
|
@ -57,7 +57,7 @@ TypePtr TruncateModInferType(const PrimitivePtr &prim, const std::vector<Abstrac
|
|||
MS_EXCEPTION(TypeError) << "For '" << prim_name << "', input must be a tensor, but got: " << z_type->ToString()
|
||||
<< ".";
|
||||
}
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kInt8, kInt32, kUInt8};
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64, kInt8, kInt32, kUInt8};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", z_type, valid_types, prim_name);
|
||||
return z_type;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue