forked from mindspore-Ecosystem/mindspore
cpu unique kernel support fp16
This commit is contained in:
parent
91a21f9477
commit
9c8ada03fe
|
@ -57,6 +57,7 @@ const char GROUP[] = "group";
|
|||
const char START[] = "start";
|
||||
const char LIMIT[] = "limit";
|
||||
const char DELTA[] = "delta";
|
||||
const char SORTED[] = "sorted";
|
||||
|
||||
enum OperateType {
|
||||
ADD = 0,
|
||||
|
|
|
@ -19,13 +19,16 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
const size_t kUseBucketUniqueSize = 100000;
|
||||
constexpr size_t kBucketSortThreshold = 100000;
|
||||
void UniqueCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
node_wpt_ = kernel_node;
|
||||
CheckParam(kernel_node);
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
input_size_ = input_shape[0];
|
||||
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
|
||||
if (AnfAlgo::HasNodeAttr(SORTED, kernel_node)) {
|
||||
sorted_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, SORTED);
|
||||
}
|
||||
}
|
||||
|
||||
void UniqueCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
|
||||
|
@ -41,9 +44,11 @@ bool UniqueCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|||
if (dtype_ == kNumberTypeInt32) {
|
||||
LaunchKernel<int, int>(inputs, workspace, outputs);
|
||||
} else if (dtype_ == kNumberTypeInt64) {
|
||||
LaunchKernel<int64_t, int>(inputs, workspace, outputs);
|
||||
} else if (dtype_ == kNumberTypeFloat32) {
|
||||
LaunchKernel<int64_t, int64_t>(inputs, workspace, outputs);
|
||||
} else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat16) {
|
||||
LaunchKernel<float, int>(inputs, workspace, outputs);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Not support type: " << dtype_;
|
||||
}
|
||||
if (!node_wpt_.expired()) {
|
||||
auto node_ = node_wpt_.lock();
|
||||
|
@ -86,12 +91,18 @@ void UniqueCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const
|
|||
params->inverse_idx_ = reinterpret_cast<IndexType *>(outputs[1]->addr);
|
||||
params->input_size_ = input_size_;
|
||||
params->output_size_ = 0;
|
||||
params->need_sort_ = true;
|
||||
|
||||
params->thread_num_ = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
|
||||
if (input_size_ < kUseBucketUniqueSize) {
|
||||
Unique(params);
|
||||
if (sorted_) {
|
||||
params->need_sort_ = true;
|
||||
if (input_size_ < kBucketSortThreshold) {
|
||||
Unique(params);
|
||||
} else {
|
||||
BucketUnique(params);
|
||||
}
|
||||
} else {
|
||||
BucketUnique(params);
|
||||
params->need_sort_ = false;
|
||||
Unique(params);
|
||||
}
|
||||
output_size_ = params->output_size_;
|
||||
}
|
||||
|
|
|
@ -60,6 +60,7 @@ class UniqueCPUKernel : public CPUKernel {
|
|||
size_t input_size_{0};
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
size_t output_size_{0};
|
||||
bool sorted_{false};
|
||||
CNodeWeakPtr node_wpt_;
|
||||
|
||||
template <typename DataType>
|
||||
|
@ -378,7 +379,7 @@ MS_REG_CPU_KERNEL(
|
|||
UniqueCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Unique, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
Unique, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
UniqueCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
|
|
|
@ -26,11 +26,13 @@ bool UniqueWithPadCPUKernel::Launch(const std::vector<kernel::AddressPtr> &input
|
|||
UniqueCPUKernel::LaunchKernel<int, int>(inputs, workspace, outputs);
|
||||
PadOutput<int>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeInt64) {
|
||||
UniqueCPUKernel::LaunchKernel<int64_t, int>(inputs, workspace, outputs);
|
||||
UniqueCPUKernel::LaunchKernel<int64_t, int64_t>(inputs, workspace, outputs);
|
||||
PadOutput<int64_t>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeFloat32) {
|
||||
} else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat16) {
|
||||
UniqueCPUKernel::LaunchKernel<float, int>(inputs, workspace, outputs);
|
||||
PadOutput<float>(inputs, outputs);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Not support data type: " << dtype_;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -51,7 +51,7 @@ MS_REG_CPU_KERNEL(UniqueWithPad,
|
|||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
UniqueWithPadCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(UniqueWithPad,
|
||||
|
|
|
@ -46,6 +46,18 @@ def test_net_fp32():
|
|||
assert (output[0].asnumpy() == expect_y_result).all()
|
||||
assert (output[1].asnumpy() == expect_idx_result).all()
|
||||
|
||||
def test_net_fp16():
|
||||
x = Tensor(np.array([1, 5, 2, 2]), mstype.float16)
|
||||
uniq = Net()
|
||||
output = uniq(x)
|
||||
print("x:\n", x)
|
||||
print("y:\n", output[0])
|
||||
print("idx:\n", output[1])
|
||||
expect_y_result = [1., 5., 2.]
|
||||
expect_idx_result = [0, 1, 2, 2]
|
||||
|
||||
assert (output[0].asnumpy() == expect_y_result).all()
|
||||
assert (output[1].asnumpy() == expect_idx_result).all()
|
||||
|
||||
def test_net_int32():
|
||||
x = Tensor(np.array([1, 2, 5, 2]), mstype.int32)
|
||||
|
|
|
@ -55,7 +55,7 @@ class UniqueWithPadCpuKernelTest : public UT::Common {
|
|||
std::vector<int64_t> x_;
|
||||
int64_t pad_dim_;
|
||||
std::vector<int64_t> out_;
|
||||
std::vector<int> idx_;
|
||||
std::vector<int64_t> idx_;
|
||||
std::vector<int64_t> workspace_idx_;
|
||||
std::vector<AddressPtr> inputs_;
|
||||
std::vector<AddressPtr> workspace_;
|
||||
|
@ -73,8 +73,8 @@ TEST_F(UniqueWithPadCpuKernelTest, compute_test) {
|
|||
unique_with_pad_->Launch(inputs_, workspace_, outputs_);
|
||||
|
||||
// check compute result
|
||||
std::vector<int64_t> expect_out{1, 2, 3, 4, 5, 8, 8, 8, 8, 8};
|
||||
std::vector<int> expect_idx{0, 0, 4, 4, 3, 3, 2, 2, 1, 1};
|
||||
std::vector<int64_t> expect_out{1, 5, 4, 3, 2, 8, 8, 8, 8, 8};
|
||||
std::vector<int64_t> expect_idx{0, 0, 1, 1, 2, 2, 3, 3, 4, 4};
|
||||
EXPECT_TRUE(out_ == expect_out);
|
||||
EXPECT_TRUE(idx_ == expect_idx);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue