!44053 fix ger mac性能问题
Merge pull request !44053 from KXiong/branch_mac_ger
This commit is contained in:
commit
b04de5de6e
|
@ -106,6 +106,20 @@ int GerCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vec
|
|||
in2dim_ = input_shape_2_[input_shape_2_.size() - 1];
|
||||
outdim_ = in1dim_ * in2dim_;
|
||||
|
||||
#ifdef __APPLE__
|
||||
if (input_type_1_ == kNumberTypeFloat64 && batches_ != kNoBatchNum) {
|
||||
launch_func_ = &GerCpuKernelMod::LaunchMacBatches<double>;
|
||||
} else if (input_type_1_ == kNumberTypeFloat64) {
|
||||
launch_func_ = &GerCpuKernelMod::LaunchMacNoBatches<double>;
|
||||
} else if (input_type_1_ == kNumberTypeFloat32 && batches_ != kNoBatchNum) {
|
||||
launch_func_ = &GerCpuKernelMod::LaunchMacBatches<float>;
|
||||
} else if (input_type_1_ == kNumberTypeFloat32) {
|
||||
launch_func_ = &GerCpuKernelMod::LaunchMacNoBatches<float>;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Ger kernel does not support " << TypeIdToString(input_type_1_);
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
#else
|
||||
if (input_type_1_ == kNumberTypeFloat64) {
|
||||
InitLaunchFunc<double>();
|
||||
} else if (input_type_1_ == kNumberTypeFloat32) {
|
||||
|
@ -118,7 +132,7 @@ int GerCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vec
|
|||
MS_LOG(ERROR) << "Ger kernel does not support " << TypeIdToString(input_type_1_);
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
|
||||
#endif
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
|
@ -228,6 +242,51 @@ bool GerCpuKernelMod::LaunchNoBatches(const std::vector<kernel::AddressPtr> &inp
|
|||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool GerCpuKernelMod::LaunchMacBatches(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
T *input1 = reinterpret_cast<T *>(inputs[kIndex0]->addr);
|
||||
T *input2 = reinterpret_cast<T *>(inputs[kIndex1]->addr);
|
||||
T *output = reinterpret_cast<T *>(outputs[kIndex0]->addr);
|
||||
auto task = [this, &input1, &input2, &output](size_t start, size_t end) {
|
||||
for (size_t batch_index = 0; batch_index < end; batch_index++) {
|
||||
size_t row_i_s = batch_index * this->in1dim_;
|
||||
size_t col_i_s = batch_index * this->in2dim_;
|
||||
size_t out_i_s = batch_index * this->in1dim_ * this->in2dim_;
|
||||
for (size_t row_i = 0; row_i < this->in1dim_; row_i++) {
|
||||
T in_one = input1[row_i_s + row_i];
|
||||
size_t out_i_i_s = out_i_s + row_i * this->in2dim_;
|
||||
for (size_t col_i = 0; col_i < this->in2dim_; col_i++) {
|
||||
output[out_i_i_s + col_i] = in_one * input2[col_i_s + col_i];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
ParallelLaunchAutoSearch(task, batches_, this, ¶llel_search_info_);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool GerCpuKernelMod::LaunchMacNoBatches(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
T *input1 = reinterpret_cast<T *>(inputs[kIndex0]->addr);
|
||||
T *input2 = reinterpret_cast<T *>(inputs[kIndex1]->addr);
|
||||
T *output = reinterpret_cast<T *>(outputs[kIndex0]->addr);
|
||||
auto task = [this, &input1, &input2, &output](size_t start, size_t end) {
|
||||
for (size_t row = start; row < end; row++) {
|
||||
T in_one = input1[row];
|
||||
size_t row_i_s = row * this->in2dim_;
|
||||
for (size_t col = 0; col < this->in2dim_; col++) {
|
||||
output[row_i_s + col] = in_one * input2[col];
|
||||
}
|
||||
}
|
||||
};
|
||||
ParallelLaunchAutoSearch(task, in1dim_, this, ¶llel_search_info_);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool GerCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &workspace,
|
||||
|
|
|
@ -77,6 +77,13 @@ class GerCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<GerC
|
|||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
bool LaunchNoBatches(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
template <typename T>
|
||||
bool LaunchMacBatches(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
template <typename T>
|
||||
bool LaunchMacNoBatches(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
|
||||
std::string kernel_type_{"Unknown"};
|
||||
TypeId input_type_1_{kTypeUnknown};
|
||||
|
|
Loading…
Reference in New Issue