forked from mindspore-Ecosystem/mindspore
refactor lerp for cpu backend.
This commit is contained in:
parent
0828efe02e
commit
f33a93ff53
|
@ -13,28 +13,62 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/lerp_cpu_kernel.h"
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
void LerpCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
// Initializes the kernel from the operator and its input/output tensors.
// Stores base_operator->name() in kernel_name_, builds a KernelAttr from the tensors, matches it against
// GetOpSupport(), and on a match takes the launch function from func_list_ at the matched index.
// Returns false (after logging an error) when inputs or outputs are empty; returns true at the end of the body.
// NOTE(review): this span looks like a rendered diff without +/- markers, so statements from the removed
// InitKernel(const CNodePtr &) version appear interleaved with the new Init body and the braces do not balance
// as shown. Confirm against the repository before relying on the exact statement list.
bool LerpCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
kernel_name_ = base_operator->name();
|
||||
if (inputs.empty() || outputs.empty()) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid.";
|
||||
return false;
|
||||
}
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
// NOTE(review): the EXCEPTION log below is presumably the removed line and the ERROR log plus `return false`
// its replacement; both are present in this rendering — verify which one is live.
MS_LOG(EXCEPTION) << kernel_name_ << " does not support this kernel data type: " << kernel_attr;
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
// Select the launch function paired with the matched KernelAttr.
kernel_func_ = func_list_[index].second;
|
||||
// NOTE(review): the statements from here down to the `output_size_ *= out_shape;` loop read `kernel_node`, which
// is not a parameter of Init; they presumably belong to the removed InitKernel version (Resize now fills these
// shape members from the tensors) — confirm.
start_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kIndex0);
|
||||
end_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kIndex1);
|
||||
weight_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kIndex2);
|
||||
output_shape_ = common::AnfAlgo::GetOutputInferShape(kernel_node, kIndex0);
|
||||
for (const auto &out_shape : output_shape_) {
|
||||
output_size_ *= out_shape;
|
||||
return true;
|
||||
}
|
||||
|
||||
int LerpCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
ResetResource();
|
||||
int ret = KRET_OK;
|
||||
if ((ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs)) != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
auto start_shape = inputs.at(kIndex0)->GetShapeVector();
|
||||
auto end_shape = inputs.at(kIndex1)->GetShapeVector();
|
||||
auto weight_shape = inputs.at(kIndex2)->GetShapeVector();
|
||||
auto output_shape = outputs.at(kIndex0)->GetShapeVector();
|
||||
(void)std::transform(start_shape.begin(), start_shape.end(), std::back_inserter(start_shape_), LongToSize);
|
||||
(void)std::transform(end_shape.begin(), end_shape.end(), std::back_inserter(end_shape_), LongToSize);
|
||||
(void)std::transform(weight_shape.begin(), weight_shape.end(), std::back_inserter(weight_shape_), LongToSize);
|
||||
(void)std::transform(output_shape.begin(), output_shape.end(), std::back_inserter(output_shape_), LongToSize);
|
||||
output_size_ = std::accumulate(output_shape_.begin(), output_shape_.end(), 1, std::multiplies<size_t>());
|
||||
return ret;
|
||||
}
|
||||
|
||||
/**
 * Returns the kernel's cached sizing state to empty: every cached shape and size list is cleared and the
 * output element count is set to 0.
 */
void LerpCpuKernelMod::ResetResource() noexcept {
  // Size lists.
  workspace_size_list_.clear();
  output_size_list_.clear();
  input_size_list_.clear();

  // Cached operand and result shapes.
  output_shape_.clear();
  weight_shape_.clear();
  end_shape_.clear();
  start_shape_.clear();

  // Element count of the output.
  output_size_ = 0;
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -53,7 +87,7 @@ bool LerpCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &input
|
|||
output[i] = static_cast<T>(start_value + (end_value - start_value) * weight_value);
|
||||
}
|
||||
};
|
||||
ParallelLaunchAutoSearch(task, output_size_, this, ¶llel_search_info_);
|
||||
ParallelLaunchAutoSearch(task, output_size_, this, ¶llel_search_info_, pool_);
|
||||
} else {
|
||||
MultipleBroadcastIterator multi_broadcast_iterator({start_shape_, end_shape_, weight_shape_}, output_shape_);
|
||||
auto *input_start = reinterpret_cast<T *>(inputs.at(kIndex0)->addr);
|
||||
|
@ -72,21 +106,19 @@ bool LerpCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &input
|
|||
iter.GenNextPos();
|
||||
}
|
||||
};
|
||||
ParallelLaunchAutoSearch(task, output_size_, this, ¶llel_search_info_);
|
||||
ParallelLaunchAutoSearch(task, output_size_, this, ¶llel_search_info_, pool_);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, LerpCpuKernelMod::LerpFunc>> LerpCpuKernelMod::func_list_ = {
|
||||
{KernelAttr()
|
||||
.AddAllSameAttr(true)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
&LerpCpuKernelMod::LaunchKernel<float16>},
|
||||
{KernelAttr()
|
||||
.AddAllSameAttr(true)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
|
|
|
@ -16,23 +16,35 @@
|
|||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LERP_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LERP_CPU_KERNEL_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <map>
|
||||
#include <functional>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class LerpCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
||||
class LerpCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
LerpCpuKernelMod() = default;
|
||||
~LerpCpuKernelMod() override = default;
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, outputs);
|
||||
}
|
||||
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||
|
||||
void ResetResource() noexcept;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
|
@ -43,7 +55,6 @@ class LerpCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
const std::vector<kernel::AddressPtr> &)>;
|
||||
size_t output_size_{1};
|
||||
LerpFunc kernel_func_;
|
||||
std::string kernel_type_;
|
||||
std::vector<size_t> start_shape_;
|
||||
std::vector<size_t> end_shape_;
|
||||
std::vector<size_t> weight_shape_;
|
||||
|
|
Loading…
Reference in New Issue