Refactor Lerp for the CPU backend.

This commit is contained in:
z00512249 2022-05-06 07:32:12 +08:00
parent 0828efe02e
commit f33a93ff53
2 changed files with 60 additions and 17 deletions

View File

@ -13,28 +13,62 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/lerp_cpu_kernel.h"
#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>
namespace mindspore {
namespace kernel {
void LerpCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
bool LerpCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
if (inputs.empty() || outputs.empty()) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid.";
return false;
}
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << kernel_name_ << " does not support this kernel data type: " << kernel_attr;
MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr;
return false;
}
kernel_func_ = func_list_[index].second;
start_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kIndex0);
end_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kIndex1);
weight_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kIndex2);
output_shape_ = common::AnfAlgo::GetOutputInferShape(kernel_node, kIndex0);
for (const auto &out_shape : output_shape_) {
output_size_ *= out_shape;
return true;
}
// Recomputes cached shapes and the flat output element count for the current
// input shapes. Returns KRET_OK on success, otherwise the base-class error code.
int LerpCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                             const std::vector<KernelTensorPtr> &outputs,
                             const std::map<uint32_t, tensor::TensorPtr> &) {
  // Drop stale shape metadata from any previous Resize before recomputing.
  ResetResource();
  const int ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs);
  if (ret != KRET_OK) {
    return ret;
  }
  // Convert each int64-based shape vector into its cached size_t counterpart.
  for (const auto &dim : inputs.at(kIndex0)->GetShapeVector()) {
    start_shape_.push_back(LongToSize(dim));
  }
  for (const auto &dim : inputs.at(kIndex1)->GetShapeVector()) {
    end_shape_.push_back(LongToSize(dim));
  }
  for (const auto &dim : inputs.at(kIndex2)->GetShapeVector()) {
    weight_shape_.push_back(LongToSize(dim));
  }
  for (const auto &dim : outputs.at(kIndex0)->GetShapeVector()) {
    output_shape_.push_back(LongToSize(dim));
  }
  // Total number of output elements = product of all output dimensions.
  output_size_ = 1;
  for (const auto dim : output_shape_) {
    output_size_ *= dim;
  }
  return ret;
}
// Clears all per-shape state so a subsequent Resize starts from a clean slate.
void LerpCpuKernelMod::ResetResource() noexcept {
  // Cached input/output shapes from the previous configuration.
  start_shape_.clear();
  end_shape_.clear();
  weight_shape_.clear();
  output_shape_.clear();
  output_size_ = 0;
  // Base-class size lists; NativeCpuKernelMod::Resize repopulates them.
  input_size_list_.clear();
  output_size_list_.clear();
  workspace_size_list_.clear();
}
template <typename T>
@ -53,7 +87,7 @@ bool LerpCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &input
output[i] = static_cast<T>(start_value + (end_value - start_value) * weight_value);
}
};
ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_, pool_);
} else {
MultipleBroadcastIterator multi_broadcast_iterator({start_shape_, end_shape_, weight_shape_}, output_shape_);
auto *input_start = reinterpret_cast<T *>(inputs.at(kIndex0)->addr);
@ -72,21 +106,19 @@ bool LerpCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &input
iter.GenNextPos();
}
};
ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_, pool_);
}
return true;
}
std::vector<std::pair<KernelAttr, LerpCpuKernelMod::LerpFunc>> LerpCpuKernelMod::func_list_ = {
{KernelAttr()
.AddAllSameAttr(true)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
&LerpCpuKernelMod::LaunchKernel<float16>},
{KernelAttr()
.AddAllSameAttr(true)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)

View File

@ -16,23 +16,35 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LERP_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LERP_CPU_KERNEL_H_
#include <string>
#include <vector>
#include <utility>
#include <map>
#include <functional>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class LerpCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class LerpCpuKernelMod : public NativeCpuKernelMod {
public:
LerpCpuKernelMod() = default;
~LerpCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
// Dispatches to the dtype-specialized launch function bound in Init.
// The middle (workspace) address list is unused by this kernel.
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, outputs);
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
void ResetResource() noexcept;
protected:
std::vector<KernelAttr> GetOpSupport() override;
@ -43,7 +55,6 @@ class LerpCpuKernelMod : public DeprecatedNativeCpuKernelMod {
const std::vector<kernel::AddressPtr> &)>;
size_t output_size_{1};
LerpFunc kernel_func_;
std::string kernel_type_;
std::vector<size_t> start_shape_;
std::vector<size_t> end_shape_;
std::vector<size_t> weight_shape_;