Merge pull request !31571 from kisnwang/clean_code
i-robot 2022-03-21 08:31:38 +00:00 committed by Gitee
commit 4dd3a94d27
10 changed files with 139 additions and 114 deletions

View File

@@ -30,7 +30,7 @@ inline void *LoadLibrary(const char *name) {
}
inline void *GetMPIAdapterHandle() {
static void *handle = LoadLibrary("libmpi_adapter.so");
void *handle = LoadLibrary("libmpi_adapter.so");
return handle;
}
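For context, the LoadLibrary helper above presumably wraps dlopen; a minimal sketch of that pattern, where the flags, error handling, and function names are assumptions rather than the adapter's actual code:

#include <dlfcn.h>
#include <iostream>

// Hypothetical dlopen-backed loader; illustrative only, not the adapter's real implementation.
inline void *LoadSharedObject(const char *name) {
  void *handle = dlopen(name, RTLD_LAZY | RTLD_LOCAL);
  if (handle == nullptr) {
    std::cerr << "dlopen " << name << " failed: " << dlerror() << std::endl;
  }
  return handle;
}

inline void *GetAdapterHandle() {
  // dlopen reference-counts repeated loads of the same library, so dropping the static
  // (as this change does) still returns the same handle, at the cost of an extra lookup per call.
  void *handle = LoadSharedObject("libmpi_adapter.so");
  return handle;
}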

View File

@@ -27,8 +27,18 @@
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kSizeFloat32 = sizeof(float);
constexpr size_t kAdamDeltaInputsNum = 9;
constexpr size_t kAdamDeltaOutputsNum = 1;
constexpr size_t kMIndex = 0;
constexpr size_t kVIndex = 1;
constexpr size_t kBeta1PowIndex = 2;
constexpr size_t kBeta2PowIndex = 3;
constexpr size_t kLRIndex = 4;
constexpr size_t kBeta1Index = 5;
constexpr size_t kBeta2Index = 6;
constexpr size_t kEpsIndex = 7;
constexpr size_t kGradIndex = 8;
} // namespace
template <typename T>
@@ -61,9 +71,9 @@ void AdamDeltaCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
std::vector<size_t> delta_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
std::vector<size_t> m_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> v_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
std::vector<size_t> grad_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 8);
std::vector<size_t> m_shape = AnfAlgo::GetInputDeviceShape(kernel_node, kMIndex);
std::vector<size_t> v_shape = AnfAlgo::GetInputDeviceShape(kernel_node, kVIndex);
std::vector<size_t> grad_shape = AnfAlgo::GetInputDeviceShape(kernel_node, kGradIndex);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
if (!IsSameShape(delta_shape, m_shape)) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
@@ -101,8 +111,9 @@ void AdamDeltaCpuKernelMod::CheckParams(const std::vector<kernel::AddressPtr> &i
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kAdamDeltaInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kAdamDeltaOutputsNum, kernel_name_);
size_t elem_size = elem_num_ * 4;
std::vector<size_t> expect_sizes = {elem_size, elem_size, 4, 4, 4, 4, 4, 4, elem_size};
size_t elem_size = elem_num_ * kSizeFloat32;
std::vector<size_t> expect_sizes = {elem_size, elem_size, kSizeFloat32, kSizeFloat32, kSizeFloat32,
kSizeFloat32, kSizeFloat32, kSizeFloat32, elem_size};
std::vector<std::string> input_names = {"m", "v", "beta1_power", "beta2_power", "lr",
"beta1", "beta2", "epsilon", "grad"};
for (size_t i = 0; i < kAdamDeltaInputsNum; ++i) {
@@ -125,18 +136,18 @@ bool AdamDeltaCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
CheckParams(inputs, outputs);
auto m = reinterpret_cast<float *>(inputs[0]->addr);
auto v = reinterpret_cast<float *>(inputs[1]->addr);
auto beta1_power = reinterpret_cast<float *>(inputs[2]->addr)[0];
auto m = reinterpret_cast<float *>(inputs[kMIndex]->addr);
auto v = reinterpret_cast<float *>(inputs[kVIndex]->addr);
auto beta1_power = reinterpret_cast<float *>(inputs[kBeta1PowIndex]->addr)[0];
if (beta1_power == 1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'beta1_power' should not be 1.";
}
auto beta2_power = reinterpret_cast<float *>(inputs[3]->addr)[0];
auto lr = reinterpret_cast<float *>(inputs[4]->addr)[0];
auto beta1 = reinterpret_cast<float *>(inputs[5]->addr)[0];
auto beta2 = reinterpret_cast<float *>(inputs[6]->addr)[0];
auto epsilon = reinterpret_cast<float *>(inputs[7]->addr)[0];
auto grad = reinterpret_cast<float *>(inputs[8]->addr);
auto beta2_power = reinterpret_cast<float *>(inputs[kBeta2PowIndex]->addr)[0];
auto lr = reinterpret_cast<float *>(inputs[kLRIndex]->addr)[0];
auto beta1 = reinterpret_cast<float *>(inputs[kBeta1Index]->addr)[0];
auto beta2 = reinterpret_cast<float *>(inputs[kBeta2Index]->addr)[0];
auto epsilon = reinterpret_cast<float *>(inputs[kEpsIndex]->addr)[0];
auto grad = reinterpret_cast<float *>(inputs[kGradIndex]->addr);
auto delta = reinterpret_cast<float *>(outputs[0]->addr);
MS_EXCEPTION_IF_NULL(m);
MS_EXCEPTION_IF_NULL(v);
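For readers mapping the new index constants back to the math, a hedged scalar sketch of the textbook Adam delta the kernel computes (the real loop is vectorized and its bias-correction arrangement may differ in detail); it also shows why beta1_power == 1 is rejected above, since the correction would divide by zero:

#include <cmath>

// Textbook Adam step for one element, using the scalars read via kLRIndex, kBeta1Index,
// kBeta2Index, kEpsIndex, kBeta1PowIndex and kBeta2PowIndex; illustrative only.
inline float AdamDeltaScalar(float &m, float &v, float grad, float lr, float beta1, float beta2,
                             float beta1_power, float beta2_power, float epsilon) {
  m = beta1 * m + (1.0f - beta1) * grad;
  v = beta2 * v + (1.0f - beta2) * grad * grad;
  const float m_hat = m / (1.0f - beta1_power);  // bias-corrected first moment
  const float v_hat = v / (1.0f - beta2_power);  // bias-corrected second moment
  return -lr * m_hat / (std::sqrt(v_hat) + epsilon);
}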

View File

@@ -29,12 +29,15 @@ constexpr size_t kEmbeddingLookupCommGradOutputsNum = 1;
void EmbeddingLookUpCommGradCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
split_num_ = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "split_num");
MS_LOG(INFO) << "split_num: " << split_num_;
auto split_num = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "split_num");
split_num_ = LongToSize(split_num);
MS_LOG(INFO) << "split_num: " << split_num;
auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (split_num_ == 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'split_num' should be greater than 0, but got 0.";
if (split_num <= 0 || split_num_ == 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'split_num' should be greater than 0, but got "
<< split_num;
}
split_num_ = LongToSize(split_num);
if (input_shape.size() < 1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the dimension of input should be at least 1-D, but got: " << input_shape.size() << "-D";
@@ -69,10 +72,11 @@ bool EmbeddingLookUpCommGradCpuKernelMod::Launch(const std::vector<kernel::Addre
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memset failed. Error no: " << ret;
}
const std::vector<int> &rank_group = {0, 1, 2, 3, 4, 5, 6, 7};
size_t input_split_lens = input_size / LongToSize(split_num_) / sizeof(float_t);
size_t output_split_lens = output_size / LongToSize(split_num_) / sizeof(float_t);
for (int64_t i = 0; i < split_num_; i++) {
MPIAllGather(input_addr + i * input_split_lens, output_addr + i * output_split_lens, rank_group, input_split_lens);
size_t input_split_lens = input_size / split_num_ / sizeof(float_t);
size_t output_split_lens = output_size / split_num_ / sizeof(float_t);
for (size_t i = 0; i < split_num_; ++i) {
(void)MPIAllGather(input_addr + i * input_split_lens, output_addr + i * output_split_lens, rank_group,
input_split_lens);
}
const uint64_t kUSecondInSecond = 1000000;
#if defined(_WIN32) || defined(_WIN64)
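The split sizing above is plain integer arithmetic; a tiny sketch with made-up numbers (the real values come from the kernel's input and output byte sizes and the split_num attribute):

#include <cstddef>

// Hypothetical numbers: splitting a 4096-byte float buffer 8 ways gives
// 4096 / 8 / sizeof(float) = 128 elements per slice, and slice i starts at
// base + i * 128, which is how the loop above feeds MPIAllGather.
constexpr size_t kInputSizeBytes = 4096;
constexpr size_t kSplitNum = 8;
constexpr size_t kInputSplitLens = kInputSizeBytes / kSplitNum / sizeof(float);
static_assert(kInputSplitLens == 128, "128 float elements per slice");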

View File

@@ -42,7 +42,7 @@ class EmbeddingLookUpCommGradCpuKernelMod : public NativeCpuKernelMod {
}
private:
int64_t split_num_;
size_t split_num_;
};
} // namespace kernel
} // namespace mindspore

View File

@@ -185,7 +185,7 @@ void FusedAdaFactorCpuKernelMod::FactorUpdate(float *update, const std::vector<A
CPUKernelUtils::ParallelFor(task, exp_avg_sq_col_elem_num, kBatchSize);
// calc update
task = [&](size_t start, size_t end) {
task = [&, this](size_t start, size_t end) {
for (size_t i = start; i < end; ++i) {
size_t row_i = i % row_dim_size;
size_t col_i = i / row_dim_size % col_dim_size;
@@ -229,8 +229,8 @@ void FusedAdaFactorCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inp
}
std::function<void(size_t, size_t)> task;
// update = grad * grad + eps[0]
task = [&](size_t start, size_t end) {
// calc update
task = [&, this](size_t start, size_t end) {
for (size_t i = start; i < end; ++i) {
auto tmp = static_cast<float>(grad[i]) * global_norm_reciprocal_;
update[i] = tmp * tmp + epsilon[0];
@@ -242,7 +242,7 @@ void FusedAdaFactorCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inp
FactorUpdate<T>(update, inputs, workspaces);
} else {
// no factor
task = [&](size_t start, size_t end) {
task = [&, this](size_t start, size_t end) {
for (size_t i = start; i < end; ++i) {
auto tmp = static_cast<float>(exp_avg_sq[i]) * beta2t + update[i] * one_minus_beta2t;
tmp = std::max(tmp, kEps);
@@ -263,7 +263,7 @@ void FusedAdaFactorCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inp
auto update_rms = CalcRMS(update, elem_num_);
auto update_rms_threshold = update_rms / clip_threshold;
auto update_coff = learning_rate / std::max(update_rms_threshold, 1.0f);
task = [&](size_t start, size_t end) {
task = [&, this](size_t start, size_t end) {
for (size_t i = start; i < end; ++i) {
update[i] = update[i] * update_coff;
if (enable_first_moment_) {
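The only change to these lambdas is naming this explicitly in the capture list, a pattern static analyzers commonly ask for; a minimal standalone illustration with a hypothetical class and members:

#include <cstddef>
#include <functional>

class Example {
 public:
  std::function<void(size_t, size_t)> MakeTask() {
    // [&, this] still captures locals by reference, but states the capture of the enclosing
    // object explicitly instead of relying on the implicit capture hidden inside [&].
    return [&, this](size_t start, size_t end) {
      for (size_t i = start; i < end; ++i) {
        sum_ += static_cast<double>(i) * scale_;  // member access goes through the captured this
      }
    };
  }

 private:
  double scale_{1.0};
  double sum_{0.0};
};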

View File

@@ -21,21 +21,39 @@
namespace mindspore {
namespace kernel {
static constexpr size_t BATCH_SIZE = 10000;
static constexpr float MIN_GLOBAL_NORM = 1e-10;
namespace {
constexpr size_t kSizeFloat32 = sizeof(float);
constexpr size_t kSizeFloat16 = sizeof(float16);
constexpr size_t kScalarIndex = 0;
constexpr size_t kFusedCastAdamWeightDecayInputNum = 10;
constexpr size_t kFusedCastAdamWeightDecayOutputNum = 3;
constexpr size_t kBatchSize = 10000;
constexpr float kMinGlobalNorm = 1e-10;
constexpr size_t kVarIndex = 0;
constexpr size_t kMIndex = 1;
constexpr size_t kVIndex = 2;
constexpr size_t kLRIndex = 3;
constexpr size_t kBeta1Index = 4;
constexpr size_t kBeta2Index = 5;
constexpr size_t kEpsIndex = 6;
constexpr size_t kDecayIndex = 7;
constexpr size_t kGradIndex = 8;
constexpr size_t kGlobalNormIndex = 9;
} // namespace
void FusedCastAdamWeightDecayCpuKernelMod::LaunchFusedCastAdamFp32(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &) {
auto m = reinterpret_cast<float *>(inputs[M]->addr);
auto v = reinterpret_cast<float *>(inputs[V]->addr);
auto lr = reinterpret_cast<float *>(inputs[LR]->addr)[kScalarIndex];
auto beta1 = reinterpret_cast<float *>(inputs[BETA1]->addr)[kScalarIndex];
auto beta2 = reinterpret_cast<float *>(inputs[BETA2]->addr)[kScalarIndex];
auto epsilon = reinterpret_cast<float *>(inputs[EPSILON]->addr)[kScalarIndex];
auto decay = reinterpret_cast<float *>(inputs[DECAY]->addr)[kScalarIndex];
auto gradient16 = reinterpret_cast<float16 *>(inputs[GRAD]->addr);
auto var = reinterpret_cast<float *>(inputs[VAR]->addr);
auto global_norm = reinterpret_cast<float *>(inputs[GLOBAL_NORM]->addr)[kScalarIndex];
if (global_norm < MIN_GLOBAL_NORM) {
auto m = reinterpret_cast<float *>(inputs[kMIndex]->addr);
auto v = reinterpret_cast<float *>(inputs[kVIndex]->addr);
auto lr = reinterpret_cast<float *>(inputs[kLRIndex]->addr)[kScalarIndex];
auto beta1 = reinterpret_cast<float *>(inputs[kBeta1Index]->addr)[kScalarIndex];
auto beta2 = reinterpret_cast<float *>(inputs[kBeta2Index]->addr)[kScalarIndex];
auto epsilon = reinterpret_cast<float *>(inputs[kEpsIndex]->addr)[kScalarIndex];
auto decay = reinterpret_cast<float *>(inputs[kDecayIndex]->addr)[kScalarIndex];
auto gradient16 = reinterpret_cast<float16 *>(inputs[kGradIndex]->addr);
auto var = reinterpret_cast<float *>(inputs[kVarIndex]->addr);
auto global_norm = reinterpret_cast<float *>(inputs[kGlobalNormIndex]->addr)[kScalarIndex];
if (global_norm < kMinGlobalNorm) {
global_norm = 1.0f;
}
auto global_norm_reciprocal = 1.0f / global_norm;
@@ -43,7 +61,7 @@ void FusedCastAdamWeightDecayCpuKernelMod::LaunchFusedCastAdamFp32(const std::ve
const auto beta2_minus = 1 - beta2;
// multithreading
size_t lens = inputs[VAR]->size > 0 ? static_cast<size_t>(inputs[VAR]->size / kSizeFloat32) : 1;
size_t lens = inputs[kVarIndex]->size > 0 ? static_cast<size_t>(inputs[kVarIndex]->size / kSizeFloat32) : 1;
std::function<void(size_t, size_t)> task;
task = [&](size_t start, size_t end) {
@@ -59,22 +77,22 @@ void FusedCastAdamWeightDecayCpuKernelMod::LaunchFusedCastAdamFp32(const std::ve
var[i] -= lr * update;
}
};
CPUKernelUtils::ParallelFor(task, lens, BATCH_SIZE);
CPUKernelUtils::ParallelFor(task, lens, kBatchSize);
}
void FusedCastAdamWeightDecayCpuKernelMod::LaunchFusedCastAdamFp16(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &) {
auto m = reinterpret_cast<float *>(inputs[M]->addr);
auto v = reinterpret_cast<float *>(inputs[V]->addr);
auto lr = reinterpret_cast<float *>(inputs[LR]->addr)[kScalarIndex];
auto beta1 = reinterpret_cast<float *>(inputs[BETA1]->addr)[kScalarIndex];
auto beta2 = reinterpret_cast<float *>(inputs[BETA2]->addr)[kScalarIndex];
auto epsilon = reinterpret_cast<float *>(inputs[EPSILON]->addr)[kScalarIndex];
auto decay = reinterpret_cast<float *>(inputs[DECAY]->addr)[kScalarIndex];
auto gradient16 = reinterpret_cast<float16 *>(inputs[GRAD]->addr);
auto var16 = reinterpret_cast<float16 *>(inputs[VAR]->addr);
auto global_norm = reinterpret_cast<float *>(inputs[GLOBAL_NORM]->addr)[kScalarIndex];
if (global_norm < MIN_GLOBAL_NORM) {
auto m = reinterpret_cast<float *>(inputs[kMIndex]->addr);
auto v = reinterpret_cast<float *>(inputs[kVIndex]->addr);
auto lr = reinterpret_cast<float *>(inputs[kLRIndex]->addr)[kScalarIndex];
auto beta1 = reinterpret_cast<float *>(inputs[kBeta1Index]->addr)[kScalarIndex];
auto beta2 = reinterpret_cast<float *>(inputs[kBeta2Index]->addr)[kScalarIndex];
auto epsilon = reinterpret_cast<float *>(inputs[kEpsIndex]->addr)[kScalarIndex];
auto decay = reinterpret_cast<float *>(inputs[kDecayIndex]->addr)[kScalarIndex];
auto gradient16 = reinterpret_cast<float16 *>(inputs[kGradIndex]->addr);
auto var16 = reinterpret_cast<float16 *>(inputs[kVarIndex]->addr);
auto global_norm = reinterpret_cast<float *>(inputs[kGlobalNormIndex]->addr)[kScalarIndex];
if (global_norm < kMinGlobalNorm) {
global_norm = 1.0f;
}
auto global_norm_reciprocal = 1.0f / global_norm;
@@ -82,7 +100,7 @@ void FusedCastAdamWeightDecayCpuKernelMod::LaunchFusedCastAdamFp16(const std::ve
const auto beta2_minus = 1 - beta2;
// multithreading
size_t lens = inputs[VAR]->size > 0 ? static_cast<size_t>(inputs[VAR]->size / kSizeFloat16) : 1;
size_t lens = inputs[kVarIndex]->size > 0 ? static_cast<size_t>(inputs[kVarIndex]->size / kSizeFloat16) : 1;
std::function<void(size_t, size_t)> task;
task = [&](size_t start, size_t end) {
@@ -100,15 +118,15 @@ void FusedCastAdamWeightDecayCpuKernelMod::LaunchFusedCastAdamFp16(const std::ve
var16[i] = static_cast<float16>(temp_var);
}
};
CPUKernelUtils::ParallelFor(task, lens, BATCH_SIZE);
CPUKernelUtils::ParallelFor(task, lens, kBatchSize);
}
void FusedCastAdamWeightDecayCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
std::vector<size_t> var_shape = AnfAlgo::GetInputDeviceShape(kernel_node, VAR);
var_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, VAR);
gradient_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, GRAD);
std::vector<size_t> var_shape = AnfAlgo::GetInputDeviceShape(kernel_node, kVarIndex);
var_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kVarIndex);
gradient_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kGradIndex);
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != kFusedCastAdamWeightDecayInputNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs should be "
@@ -149,41 +167,41 @@ void FusedCastAdamWeightDecayCpuKernelMod::CheckParam(const std::vector<kernel::
size_t elem_size_fp32 = elem_num_ * kSizeFloat32;
size_t elem_size_fp16 = elem_num_ * kSizeFloat16;
size_t var_size = var_dtype_ == kNumberTypeFloat16 ? elem_size_fp16 : elem_size_fp32;
if (inputs[VAR]->size != var_size) {
if (inputs[kVarIndex]->size != var_size) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'var' should be " << var_size
<< ", but got " << inputs[VAR]->size;
<< ", but got " << inputs[kVarIndex]->size;
}
if (inputs[M]->size != elem_size_fp32) {
if (inputs[kMIndex]->size != elem_size_fp32) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'm' should be " << elem_size_fp32
<< ", but got " << inputs[M]->size;
<< ", but got " << inputs[kMIndex]->size;
}
if (inputs[V]->size != elem_size_fp32) {
if (inputs[kVIndex]->size != elem_size_fp32) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'v' should be " << elem_size_fp32
<< ", but got " << inputs[V]->size;
<< ", but got " << inputs[kVIndex]->size;
}
if (inputs[GRAD]->size != elem_size_fp16) {
if (inputs[kGradIndex]->size != elem_size_fp16) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the address size of 'gradient' should be " << elem_size_fp16
<< ", but got " << inputs[GRAD]->size;
<< ", but got " << inputs[kGradIndex]->size;
}
if (inputs[LR]->size != kSizeFloat32) {
if (inputs[kLRIndex]->size != kSizeFloat32) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the type of 'lr' should be float32, but got 'lr': " << inputs[LR];
<< "', the type of 'lr' should be float32, but got 'lr': " << inputs[kLRIndex];
}
if (inputs[BETA1]->size != kSizeFloat32) {
if (inputs[kBeta1Index]->size != kSizeFloat32) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the type of 'beta1' should be float32, but got 'beta1': " << inputs[BETA1];
<< "', the type of 'beta1' should be float32, but got 'beta1': " << inputs[kBeta1Index];
}
if (inputs[BETA2]->size != kSizeFloat32) {
if (inputs[kBeta2Index]->size != kSizeFloat32) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the type of 'beta2' should be float32, but got 'beta2': " << inputs[BETA2];
<< "', the type of 'beta2' should be float32, but got 'beta2': " << inputs[kBeta2Index];
}
if (inputs[EPSILON]->size != kSizeFloat32) {
if (inputs[kEpsIndex]->size != kSizeFloat32) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the type of 'epsilon' should be float32, but got 'epsilon': " << inputs[EPSILON];
<< "', the type of 'epsilon' should be float32, but got 'epsilon': " << inputs[kEpsIndex];
}
if (inputs[DECAY]->size != kSizeFloat32) {
if (inputs[kDecayIndex]->size != kSizeFloat32) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the type of 'decay' should be float32, but got 'decay': " << inputs[DECAY];
<< "', the type of 'decay' should be float32, but got 'decay': " << inputs[kDecayIndex];
}
}
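A hedged sketch of the per-element fp32 arithmetic these launch paths run (float16 gradients are cast up to float first; the real kernel iterates over [start, end) slices inside ParallelFor and may differ in detail):

#include <cmath>

// Illustrative per-element AdamWeightDecay step, using the same scalar names the kernel
// reads via kLRIndex, kBeta1Index, kBeta2Index, kEpsIndex and kDecayIndex.
inline void FusedAdamWeightDecayElement(float &var, float &m, float &v, float grad_fp32, float lr,
                                        float beta1, float beta2, float epsilon, float decay,
                                        float global_norm_reciprocal) {
  const float grad = grad_fp32 * global_norm_reciprocal;            // unscale by the global norm
  m += (grad - m) * (1.0f - beta1);                                 // first moment
  v += (grad * grad - v) * (1.0f - beta2);                          // second moment
  const float update = m / (std::sqrt(v) + epsilon) + decay * var;  // decay applied to the weight
  var -= lr * update;
}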

View File

@@ -23,12 +23,6 @@
namespace mindspore {
namespace kernel {
constexpr size_t kSizeFloat32 = sizeof(float);
constexpr size_t kSizeFloat16 = sizeof(float16);
constexpr size_t kScalarIndex = 0;
constexpr size_t kFusedCastAdamWeightDecayInputNum = 10;
constexpr size_t kFusedCastAdamWeightDecayOutputNum = 3;
class FusedCastAdamWeightDecayCpuKernelMod : public NativeCpuKernelMod {
public:
FusedCastAdamWeightDecayCpuKernelMod() = default;
@@ -79,7 +73,6 @@ class FusedCastAdamWeightDecayCpuKernelMod : public NativeCpuKernelMod {
size_t elem_num_{0};
TypeId var_dtype_{kTypeUnknown};
TypeId gradient_dtype_{kTypeUnknown};
enum input_list_ { VAR, M, V, LR, BETA1, BETA2, EPSILON, DECAY, GRAD, GLOBAL_NORM };
};
} // namespace kernel
} // namespace mindspore
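This header now only loses code: the size/index constants and the old input_list_ enum move into the unnamed namespace of the .cc shown above, giving them internal linkage so they no longer leak into every file that includes the header. A minimal sketch of the pattern, with placeholder names:

#include <cstddef>

namespace {
// Internal linkage: these names are visible only in this translation unit, so they cannot
// collide with identically named constants defined by other kernels.
constexpr size_t kSizeFloat32 = sizeof(float);
constexpr size_t kVarIndex = 0;
}  // namespace

// Expected byte size of input `index`: the var tensor holds elem_num floats, scalars hold one.
size_t ExpectedInputBytes(size_t index, size_t elem_num) {
  return index == kVarIndex ? elem_num * kSizeFloat32 : kSizeFloat32;
}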

View File

@@ -19,11 +19,6 @@
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kUniqueWithPadInputsNum = 2;
constexpr size_t kUniqueWithPadOutputsNum = 2;
} // namespace
bool UniqueWithPadCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs) {
@@ -31,13 +26,13 @@ bool UniqueWithPadCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &in
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kUniqueWithPadOutputsNum, kernel_name_);
if (dtype_ == kNumberTypeInt32) {
UniqueCpuKernelMod::LaunchKernel<int, int>(inputs, workspace, outputs);
PadOutput<int>(inputs, outputs);
PadOutput<int>(inputs, outputs, output_size_);
} else if (dtype_ == kNumberTypeInt64) {
UniqueCpuKernelMod::LaunchKernel<int64_t, int64_t>(inputs, workspace, outputs);
PadOutput<int64_t>(inputs, outputs);
PadOutput<int64_t>(inputs, outputs, output_size_);
} else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat16) {
UniqueCpuKernelMod::LaunchKernel<float, int>(inputs, workspace, outputs);
PadOutput<float>(inputs, outputs);
PadOutput<float>(inputs, outputs, output_size_);
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the dtype of input should be float16, float32, int32, or int64, but got "
@@ -46,16 +41,6 @@ bool UniqueWithPadCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &in
return true;
}
template <typename T>
void UniqueWithPadCpuKernelMod::PadOutput(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) const {
auto pad_num = *reinterpret_cast<T *>(inputs[1]->addr);
auto *out = reinterpret_cast<T *>(outputs[0]->addr);
for (size_t i = output_size_; i < input_size_; ++i) {
out[i] = pad_num;
}
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, UniqueWithPad, UniqueWithPadCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@@ -54,8 +54,21 @@ class UniqueWithPadCpuKernelMod : public UniqueCpuKernelMod {
}
private:
inline static constexpr size_t kUniqueWithPadInputsNum = 2;
inline static constexpr size_t kUniqueWithPadOutputsNum = 2;
template <typename T>
void PadOutput(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) const;
static void PadOutput(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs, size_t start) {
if (inputs.size() < kUniqueWithPadInputsNum || outputs.size() < kUniqueWithPadOutputsNum) {
return;
}
auto pad_num = *reinterpret_cast<T *>(inputs[1]->addr);
auto *out = reinterpret_cast<T *>(outputs[0]->addr);
size_t length = outputs[0]->size / sizeof(T);
for (size_t i = start; i < length; ++i) {
out[i] = pad_num;
}
}
};
} // namespace kernel
} // namespace mindspore
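To see the relocated PadOutput end to end, a self-contained sketch with a stand-in Address type (the real AddressPtr comes from the kernel framework; names here are hypothetical):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

// Minimal stand-in for the framework's Address/AddressPtr, just enough to exercise the logic.
struct Address {
  void *addr{nullptr};
  size_t size{0};
};
using AddressPtr = std::shared_ptr<Address>;

template <typename T>
void PadTail(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs, size_t start) {
  auto pad_num = *reinterpret_cast<T *>(inputs[1]->addr);  // the second input carries the pad value
  auto *out = reinterpret_cast<T *>(outputs[0]->addr);
  size_t length = outputs[0]->size / sizeof(T);            // why the test below now fills Address::size
  for (size_t i = start; i < length; ++i) {
    out[i] = pad_num;
  }
}

int main() {
  std::vector<int64_t> data = {5, 4, 3, 0, 0};  // first three slots hold the unique values
  int64_t pad = -1;
  auto in = std::make_shared<Address>(Address{&pad, sizeof(pad)});
  auto out = std::make_shared<Address>(Address{data.data(), data.size() * sizeof(int64_t)});
  PadTail<int64_t>({nullptr, in}, {out}, 3);  // pads indices 3..4 with -1
  for (auto v : data) {
    std::cout << v << ' ';  // prints: 5 4 3 -1 -1
  }
  std::cout << std::endl;
  return 0;
}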

View File

@@ -36,20 +36,21 @@ class UniqueWithPadCpuKernelTest : public UT::Common {
outputs_.clear();
}
AddressPtr CreateKernelAddress(void *addr) {
AddressPtr CreateKernelAddress(void *addr, size_t size) {
auto kernel_addr = std::make_shared<Address>();
kernel_addr->addr = addr;
kernel_addr->size = size;
return kernel_addr;
}
void CreateAddress() {
inputs_.push_back(CreateKernelAddress(x_.data()));
inputs_.push_back(CreateKernelAddress(&pad_dim_));
outputs_.push_back(CreateKernelAddress(out_.data()));
outputs_.push_back(CreateKernelAddress(idx_.data()));
workspace_.push_back(CreateKernelAddress(workspace_idx_.data()));
workspace_.push_back(CreateKernelAddress(workspace_idx_.data()));
workspace_.push_back(CreateKernelAddress(workspace_idx_.data()));
void CreateAddress(size_t type_size) {
inputs_.push_back(CreateKernelAddress(x_.data(), x_.size() * type_size));
inputs_.push_back(CreateKernelAddress(&pad_dim_, type_size));
outputs_.push_back(CreateKernelAddress(out_.data(), out_.size() * type_size));
outputs_.push_back(CreateKernelAddress(idx_.data(), idx_.size() * type_size));
workspace_.push_back(CreateKernelAddress(workspace_idx_.data(), workspace_idx_.size() * type_size));
workspace_.push_back(CreateKernelAddress(workspace_idx_.data(), workspace_idx_.size() * type_size));
workspace_.push_back(CreateKernelAddress(workspace_idx_.data(), workspace_idx_.size() * type_size));
}
std::vector<int64_t> x_;
@@ -69,7 +70,7 @@ TEST_F(UniqueWithPadCpuKernelTest, compute_test) {
out_ = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
idx_ = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
workspace_idx_ = {1, 1, 1, 1, 1, 1, 1, 1, 1};
CreateAddress();
CreateAddress(sizeof(int64_t));
unique_with_pad_->Launch(inputs_, workspace_, outputs_);
// check compute result