!44053 fix ger mac性能问题

Merge pull request !44053 from KXiong/branch_mac_ger
This commit is contained in:
i-robot 2022-10-18 01:49:33 +00:00 committed by Gitee
commit b04de5de6e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 67 additions and 1 deletions

View File

@ -106,6 +106,20 @@ int GerCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vec
in2dim_ = input_shape_2_[input_shape_2_.size() - 1];
outdim_ = in1dim_ * in2dim_;
#ifdef __APPLE__
if (input_type_1_ == kNumberTypeFloat64 && batches_ != kNoBatchNum) {
launch_func_ = &GerCpuKernelMod::LaunchMacBatches<double>;
} else if (input_type_1_ == kNumberTypeFloat64) {
launch_func_ = &GerCpuKernelMod::LaunchMacNoBatches<double>;
} else if (input_type_1_ == kNumberTypeFloat32 && batches_ != kNoBatchNum) {
launch_func_ = &GerCpuKernelMod::LaunchMacBatches<float>;
} else if (input_type_1_ == kNumberTypeFloat32) {
launch_func_ = &GerCpuKernelMod::LaunchMacNoBatches<float>;
} else {
MS_LOG(ERROR) << "Ger kernel does not support " << TypeIdToString(input_type_1_);
return KRET_RESIZE_FAILED;
}
#else
if (input_type_1_ == kNumberTypeFloat64) {
InitLaunchFunc<double>();
} else if (input_type_1_ == kNumberTypeFloat32) {
@ -118,7 +132,7 @@ int GerCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vec
MS_LOG(ERROR) << "Ger kernel does not support " << TypeIdToString(input_type_1_);
return KRET_RESIZE_FAILED;
}
#endif
return KRET_OK;
}
@ -228,6 +242,51 @@ bool GerCpuKernelMod::LaunchNoBatches(const std::vector<kernel::AddressPtr> &inp
return true;
}
template <typename T>
bool GerCpuKernelMod::LaunchMacBatches(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
T *input1 = reinterpret_cast<T *>(inputs[kIndex0]->addr);
T *input2 = reinterpret_cast<T *>(inputs[kIndex1]->addr);
T *output = reinterpret_cast<T *>(outputs[kIndex0]->addr);
auto task = [this, &input1, &input2, &output](size_t start, size_t end) {
for (size_t batch_index = 0; batch_index < end; batch_index++) {
size_t row_i_s = batch_index * this->in1dim_;
size_t col_i_s = batch_index * this->in2dim_;
size_t out_i_s = batch_index * this->in1dim_ * this->in2dim_;
for (size_t row_i = 0; row_i < this->in1dim_; row_i++) {
T in_one = input1[row_i_s + row_i];
size_t out_i_i_s = out_i_s + row_i * this->in2dim_;
for (size_t col_i = 0; col_i < this->in2dim_; col_i++) {
output[out_i_i_s + col_i] = in_one * input2[col_i_s + col_i];
}
}
}
};
ParallelLaunchAutoSearch(task, batches_, this, &parallel_search_info_);
return true;
}
template <typename T>
bool GerCpuKernelMod::LaunchMacNoBatches(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
T *input1 = reinterpret_cast<T *>(inputs[kIndex0]->addr);
T *input2 = reinterpret_cast<T *>(inputs[kIndex1]->addr);
T *output = reinterpret_cast<T *>(outputs[kIndex0]->addr);
auto task = [this, &input1, &input2, &output](size_t start, size_t end) {
for (size_t row = start; row < end; row++) {
T in_one = input1[row];
size_t row_i_s = row * this->in2dim_;
for (size_t col = 0; col < this->in2dim_; col++) {
output[row_i_s + col] = in_one * input2[col];
}
}
};
ParallelLaunchAutoSearch(task, in1dim_, this, &parallel_search_info_);
return true;
}
template <typename T>
bool GerCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,

View File

@ -77,6 +77,13 @@ class GerCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<GerC
const std::vector<kernel::AddressPtr> &outputs);
bool LaunchNoBatches(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
template <typename T>
bool LaunchMacBatches(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
template <typename T>
bool LaunchMacNoBatches(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
std::string kernel_type_{"Unknown"};
TypeId input_type_1_{kTypeUnknown};