forked from mindspore-Ecosystem/mindspore
!3543 Split unsupport transdata
Merge pull request !3543 from lianliguang/unify-primitive
This commit is contained in:
commit
7cb567ebbe
|
@ -158,13 +158,13 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetProcessor(Processor processor)
|
||||||
|
|
||||||
// Hands back the KernelBuildInfo this builder has been filling in.
// The returned pointer is the builder's own shared_ptr, not a copy of the object.
std::shared_ptr<KernelBuildInfo> KernelBuildInfo::KernelBuildInfoBuilder::Build() {
  return kernel_build_info_;
}
|
||||||
|
|
||||||
void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType(
|
void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsReshapeType(
|
||||||
const std::vector<std::vector<Axis>> &input_reshape_type) {
|
const std::vector<std::vector<Axis>> &input_reshape_type) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_build_info_);
|
MS_EXCEPTION_IF_NULL(kernel_build_info_);
|
||||||
kernel_build_info_->input_reshape_type_ = input_reshape_type;
|
kernel_build_info_->input_reshape_type_ = input_reshape_type;
|
||||||
}
|
}
|
||||||
|
|
||||||
void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType(
|
void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsReshapeType(
|
||||||
const std::vector<std::vector<Axis>> &output_reshape_type) {
|
const std::vector<std::vector<Axis>> &output_reshape_type) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_build_info_);
|
MS_EXCEPTION_IF_NULL(kernel_build_info_);
|
||||||
kernel_build_info_->output_reshape_type_ = output_reshape_type;
|
kernel_build_info_->output_reshape_type_ = output_reshape_type;
|
||||||
|
@ -189,5 +189,36 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputFormat(const std::string
|
||||||
}
|
}
|
||||||
kernel_build_info_->outputs_format_[index] = format;
|
kernel_build_info_->outputs_format_[index] = format;
|
||||||
}
|
}
|
||||||
|
void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType(const std::vector<Axis> &input_reshape_type,
|
||||||
|
size_t index) {
|
||||||
|
if (index >= kernel_build_info_->input_reshape_type_.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "index outof range!";
|
||||||
|
}
|
||||||
|
std::copy(input_reshape_type.begin(), input_reshape_type.end(),
|
||||||
|
std::back_inserter(kernel_build_info_->input_reshape_type_[index]));
|
||||||
|
}
|
||||||
|
|
||||||
|
void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType(const std::vector<Axis> &output_reshape_type,
|
||||||
|
size_t index) {
|
||||||
|
if (index >= kernel_build_info_->output_reshape_type_.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "index outof range!";
|
||||||
|
}
|
||||||
|
std::copy(output_reshape_type.begin(), output_reshape_type.end(),
|
||||||
|
std::back_inserter(kernel_build_info_->output_reshape_type_[index]));
|
||||||
|
}
|
||||||
|
|
||||||
|
void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputDeviceType(const TypeId &output_device_type, size_t index) {
|
||||||
|
if (index >= kernel_build_info_->outputs_device_type_.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "index outof range!";
|
||||||
|
}
|
||||||
|
kernel_build_info_->outputs_device_type_[index] = output_device_type;
|
||||||
|
}
|
||||||
|
|
||||||
|
void KernelBuildInfo::KernelBuildInfoBuilder::SetInputDeviceType(const TypeId &input_device_type, size_t index) {
|
||||||
|
if (index >= kernel_build_info_->inputs_device_type_.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "index outof range!";
|
||||||
|
}
|
||||||
|
kernel_build_info_->inputs_device_type_[index] = input_device_type;
|
||||||
|
}
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -71,6 +71,10 @@ class KernelBuildInfo {
|
||||||
|
|
||||||
std::vector<TypeId> GetAllOutputDeviceTypes() const;
|
std::vector<TypeId> GetAllOutputDeviceTypes() const;
|
||||||
|
|
||||||
|
std::vector<std::vector<Axis>> GetAllOutputReshapeType() const;
|
||||||
|
|
||||||
|
std::vector<std::vector<Axis>> GetAllInputReshapeType() const;
|
||||||
|
|
||||||
OpPattern op_pattern() const { return op_pattern_; }
|
OpPattern op_pattern() const { return op_pattern_; }
|
||||||
|
|
||||||
FusionType fusion_type() const { return fusion_type_; }
|
FusionType fusion_type() const { return fusion_type_; }
|
||||||
|
@ -109,7 +113,22 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
|
||||||
KernelBuildInfoBuilder() { kernel_build_info_ = std::make_shared<KernelBuildInfo>(); }
|
KernelBuildInfoBuilder() { kernel_build_info_ = std::make_shared<KernelBuildInfo>(); }
|
||||||
|
|
||||||
explicit KernelBuildInfoBuilder(std::shared_ptr<KernelBuildInfo> kernel_build_info)
|
explicit KernelBuildInfoBuilder(std::shared_ptr<KernelBuildInfo> kernel_build_info)
|
||||||
: kernel_build_info_(std::move(kernel_build_info)) {}
|
: kernel_build_info_(std::make_shared<KernelBuildInfo>()) {
|
||||||
|
SetKernelType(kernel_build_info->kernel_type());
|
||||||
|
SetFusionType(kernel_build_info->fusion_type());
|
||||||
|
SetProcessor(kernel_build_info->processor());
|
||||||
|
OpPattern(kernel_build_info->op_pattern());
|
||||||
|
for (size_t index = 0; index < kernel_build_info->GetInputNum(); ++index) {
|
||||||
|
kernel_build_info_->inputs_device_type_.emplace_back(kernel_build_info->GetInputDeviceType(index));
|
||||||
|
kernel_build_info_->inputs_format_.emplace_back(kernel_build_info->GetInputFormat(index));
|
||||||
|
kernel_build_info_->input_reshape_type_.emplace_back(kernel_build_info->GetInputReshapeType(index));
|
||||||
|
}
|
||||||
|
for (size_t index = 0; index < kernel_build_info->GetOutputNum(); ++index) {
|
||||||
|
kernel_build_info_->outputs_device_type_.emplace_back(kernel_build_info->GetOutputDeviceType(index));
|
||||||
|
kernel_build_info_->outputs_format_.emplace_back(kernel_build_info->GetOutputFormat(index));
|
||||||
|
kernel_build_info_->output_reshape_type_.emplace_back(kernel_build_info->GetOutputReshapeType(index));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
~KernelBuildInfoBuilder() = default;
|
~KernelBuildInfoBuilder() = default;
|
||||||
|
|
||||||
|
@ -123,9 +142,9 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
|
||||||
|
|
||||||
void SetOutputsDeviceType(const std::vector<TypeId> &outputs_device_type);
|
void SetOutputsDeviceType(const std::vector<TypeId> &outputs_device_type);
|
||||||
|
|
||||||
void SetInputReshapeType(const std::vector<std::vector<Axis>> &input_reshape_type);
|
void SetInputsReshapeType(const std::vector<std::vector<Axis>> &input_reshape_type);
|
||||||
|
|
||||||
void SetOutputReshapeType(const std::vector<std::vector<Axis>> &output_reshape_type);
|
void SetOutputsReshapeType(const std::vector<std::vector<Axis>> &output_reshape_type);
|
||||||
|
|
||||||
void SetFusionType(FusionType fusion_type);
|
void SetFusionType(FusionType fusion_type);
|
||||||
|
|
||||||
|
@ -137,6 +156,14 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
|
||||||
|
|
||||||
void SetOutputFormat(const std::string &format, size_t index);
|
void SetOutputFormat(const std::string &format, size_t index);
|
||||||
|
|
||||||
|
void SetInputReshapeType(const std::vector<Axis> &input_reshape_type, size_t index);
|
||||||
|
|
||||||
|
void SetOutputReshapeType(const std::vector<Axis> &output_reshape_type, size_t index);
|
||||||
|
|
||||||
|
void SetInputDeviceType(const TypeId &input_device_type, size_t index);
|
||||||
|
|
||||||
|
void SetOutputDeviceType(const TypeId &output_device_type, size_t index);
|
||||||
|
|
||||||
std::shared_ptr<KernelBuildInfo> Build();
|
std::shared_ptr<KernelBuildInfo> Build();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -118,7 +118,7 @@ void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
|
||||||
}
|
}
|
||||||
builder.SetInputsDeviceType(inputs_device_type);
|
builder.SetInputsDeviceType(inputs_device_type);
|
||||||
builder.SetInputsFormat(inputs_format);
|
builder.SetInputsFormat(inputs_format);
|
||||||
builder.SetInputReshapeType(inputs_reshape_type);
|
builder.SetInputsReshapeType(inputs_reshape_type);
|
||||||
// output
|
// output
|
||||||
std::vector<std::string> outputs_format;
|
std::vector<std::string> outputs_format;
|
||||||
std::vector<TypeId> outputs_device_type;
|
std::vector<TypeId> outputs_device_type;
|
||||||
|
@ -129,7 +129,7 @@ void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
|
||||||
}
|
}
|
||||||
builder.SetOutputsDeviceType(outputs_device_type);
|
builder.SetOutputsDeviceType(outputs_device_type);
|
||||||
builder.SetOutputsFormat(outputs_format);
|
builder.SetOutputsFormat(outputs_format);
|
||||||
builder.SetOutputReshapeType(outputs_reshape_type);
|
builder.SetOutputsReshapeType(outputs_reshape_type);
|
||||||
kernel_info_list_->emplace_back(builder.Build());
|
kernel_info_list_->emplace_back(builder.Build());
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "end.";
|
MS_LOG(INFO) << "end.";
|
||||||
|
|
|
@ -59,6 +59,7 @@
|
||||||
#include "backend/optimizer/ascend/format_type/insert_trans_op.h"
|
#include "backend/optimizer/ascend/format_type/insert_trans_op.h"
|
||||||
#include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h"
|
#include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h"
|
||||||
#include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h"
|
#include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h"
|
||||||
|
#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
|
||||||
#include "backend/optimizer/pass/getitem_tuple.h"
|
#include "backend/optimizer/pass/getitem_tuple.h"
|
||||||
#include "backend/optimizer/pass/optimize_dependence.h"
|
#include "backend/optimizer/pass/optimize_dependence.h"
|
||||||
#include "backend/optimizer/pass/erase_visit_attr.h"
|
#include "backend/optimizer/pass/erase_visit_attr.h"
|
||||||
|
@ -228,6 +229,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
|
||||||
mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>());
|
mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>());
|
||||||
mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>());
|
mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>());
|
||||||
mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
|
mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
|
||||||
|
mixed_precision_pm->AddPass(std::make_shared<SplitUnsupportedTransData>());
|
||||||
mixed_precision_pm->AddPass(std::make_shared<ConvertUnSupportNodeToAICPU>());
|
mixed_precision_pm->AddPass(std::make_shared<ConvertUnSupportNodeToAICPU>());
|
||||||
mixed_precision_pm->AddPass(std::make_shared<RemoveInternalOutputCast>());
|
mixed_precision_pm->AddPass(std::make_shared<RemoveInternalOutputCast>());
|
||||||
optimizer->AddPassManager(mixed_precision_pm);
|
optimizer->AddPassManager(mixed_precision_pm);
|
||||||
|
|
|
@ -174,8 +174,8 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string &
|
||||||
MS_EXCEPTION_IF_NULL(ori_build_info);
|
MS_EXCEPTION_IF_NULL(ori_build_info);
|
||||||
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(ori_build_info);
|
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(ori_build_info);
|
||||||
builder->SetInputsFormat({input_format});
|
builder->SetInputsFormat({input_format});
|
||||||
builder->SetInputReshapeType({reshape_type});
|
builder->SetInputsReshapeType({reshape_type});
|
||||||
builder->SetOutputReshapeType({reshape_type});
|
builder->SetOutputsReshapeType({reshape_type});
|
||||||
builder->SetOutputsFormat({output_format});
|
builder->SetOutputsFormat({output_format});
|
||||||
if (type_id != kTypeUnknown) {
|
if (type_id != kTypeUnknown) {
|
||||||
builder->SetOutputsDeviceType({type_id});
|
builder->SetOutputsDeviceType({type_id});
|
||||||
|
|
|
@ -0,0 +1,65 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
// Pattern matched by this pass: a TransData primitive applied to one arbitrary input.
const BaseRef SplitUnsupportedTransData::DefinePattern() const {
  // Placeholder variable standing for the single operand of TransData.
  VarPtr any_input = std::make_shared<Var>();
  return VectorRef({prim::KPrimTransData, any_input});
}
|
||||||
|
|
||||||
|
// Entry point invoked for every node matching DefinePattern().
//
// Returns nullptr when `node` is null, is not a CNode, is not a real kernel, or is not
// named like the TransData primitive. Otherwise the node's selected kernel build info
// must describe exactly one input and one output (anything else raises through
// MS_LOG(EXCEPTION)), and the result of SplitTransData() is returned.
const AnfNodePtr SplitUnsupportedTransData::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                    const EquivPtr &) const {
  if (node == nullptr) {
    return nullptr;
  }
  if (!node->isa<CNode>()) {
    return nullptr;
  }
  if (!AnfAlgo::IsRealKernel(node)) {
    return nullptr;
  }
  auto trans_data_cnode = node->cast<CNodePtr>();
  const bool is_trans_data = AnfAlgo::GetCNodeName(trans_data_cnode) == prim::KPrimTransData->name();
  if (!is_trans_data) {
    return nullptr;
  }
  auto build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data_cnode);
  MS_EXCEPTION_IF_NULL(build_info);
  // Output count is only inspected when the input count is already 1.
  const bool single_in_single_out = build_info->GetInputNum() == 1 && build_info->GetOutputNum() == 1;
  if (!single_in_single_out) {
    MS_LOG(EXCEPTION) << "Transdata node's kernel info's input and output format size is not 1"
                      << trans_data_cnode->DebugString();
  }
  return SplitTransData(func_graph, trans_data_cnode);
}
|
||||||
|
// Splits a TransData whose input and output formats are both in kHWSpecialFormatSet
// into two chained TransData nodes that go through kOpFormat_DEFAULT.
//
// func_graph: graph in which the additional TransData node is created.
// trans_node: the TransData node to split; its selected kernel build info is rewritten
//             so that its output format becomes kOpFormat_DEFAULT.
//
// Returns `trans_node` unchanged when its input format or its output format is not in
// kHWSpecialFormatSet. Otherwise returns the new TransData node, which consumes
// `trans_node`, reads kOpFormat_DEFAULT and keeps the original output format.
// Null arguments, a missing build info or a failed node creation raise through
// MS_EXCEPTION_IF_NULL.
AnfNodePtr SplitUnsupportedTransData::SplitTransData(const FuncGraphPtr &func_graph, const CNodePtr &trans_node) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(trans_node);
  auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(trans_node);
  MS_EXCEPTION_IF_NULL(kernel_info);
  if (kHWSpecialFormatSet.find(kernel_info->GetInputFormat(0)) == kHWSpecialFormatSet.end() ||
      kHWSpecialFormatSet.find(kernel_info->GetOutputFormat(0)) == kHWSpecialFormatSet.end()) {
    return trans_node;
  }
  // Both builders start from the original build info and differ only in the side that
  // is switched to the default format.
  auto builder_info_to_default = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info);
  auto builder_info_to_special_format = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info);
  builder_info_to_default->SetOutputsFormat({kOpFormat_DEFAULT});
  builder_info_to_special_format->SetInputsFormat({kOpFormat_DEFAULT});
  std::vector<AnfNodePtr> next_trans_node_inputs = {
    NewValueNode(std::make_shared<Primitive>(prim::KPrimTransData->name())), trans_node};
  auto next_trans_node = func_graph->NewCNode(next_trans_node_inputs);
  MS_EXCEPTION_IF_NULL(next_trans_node);
  next_trans_node->set_abstract(trans_node->abstract());
  AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_default->Build(), trans_node.get());
  AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_special_format->Build(), next_trans_node.get());
  return next_trans_node;
}
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,37 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H
|
||||||
|
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H
|
||||||
|
|
||||||
|
#include "backend/optimizer/common/optimizer.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
class SplitUnsupportedTransData : public PatternProcessPass {
|
||||||
|
public:
|
||||||
|
explicit SplitUnsupportedTransData(bool multigraph = true)
|
||||||
|
: PatternProcessPass("split_unsupported_transdata", multigraph) {}
|
||||||
|
~SplitUnsupportedTransData() override = default;
|
||||||
|
const BaseRef DefinePattern() const override;
|
||||||
|
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
AnfNodePtr SplitTransData(const FuncGraphPtr &func_graph, const CNodePtr &trans_node) const;
|
||||||
|
};
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H
|
|
@ -50,6 +50,8 @@ class TestHWInsertTransOp : public BackendCommon {
|
||||||
KernelBuildInfoBuilder builder;
|
KernelBuildInfoBuilder builder;
|
||||||
builder.SetInputsFormat({format, format});
|
builder.SetInputsFormat({format, format});
|
||||||
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
|
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
|
||||||
|
builder.SetInputsReshapeType({{},{}});
|
||||||
|
builder.SetOutputsReshapeType({});
|
||||||
builder.SetOutputsFormat({format});
|
builder.SetOutputsFormat({format});
|
||||||
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
||||||
add->set_kernel_info(std::make_shared<device::KernelInfo>());
|
add->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||||
|
@ -70,6 +72,8 @@ class TestHWInsertTransOp : public BackendCommon {
|
||||||
EXPECT_NE(ret->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>()->input(1), nullptr);
|
EXPECT_NE(ret->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>()->input(1), nullptr);
|
||||||
auto max_pool = ret->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>()->input(1);
|
auto max_pool = ret->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>()->input(1);
|
||||||
KernelBuildInfoBuilder builder;
|
KernelBuildInfoBuilder builder;
|
||||||
|
builder.SetInputsReshapeType({{}});
|
||||||
|
builder.SetOutputsReshapeType({{},{}});
|
||||||
builder.SetInputsFormat({kOpFormat_DEFAULT});
|
builder.SetInputsFormat({kOpFormat_DEFAULT});
|
||||||
builder.SetInputsDeviceType({kFloat16->type_id()});
|
builder.SetInputsDeviceType({kFloat16->type_id()});
|
||||||
builder.SetOutputsFormat({format, format});
|
builder.SetOutputsFormat({format, format});
|
||||||
|
@ -88,6 +92,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
|
||||||
~MockInsertTransOpKernelSelectTrans4Dto5D() override = default;
|
~MockInsertTransOpKernelSelectTrans4Dto5D() override = default;
|
||||||
void SelectKernel(const CNodePtr &cnode) override {
|
void SelectKernel(const CNodePtr &cnode) override {
|
||||||
KernelBuildInfoBuilder builder;
|
KernelBuildInfoBuilder builder;
|
||||||
|
builder.SetInputsReshapeType({{}});
|
||||||
|
builder.SetOutputsReshapeType({{}});
|
||||||
builder.SetInputsFormat({"NCHW"});
|
builder.SetInputsFormat({"NCHW"});
|
||||||
builder.SetInputsDeviceType({kFloat16->type_id()});
|
builder.SetInputsDeviceType({kFloat16->type_id()});
|
||||||
builder.SetOutputsFormat({"NC1HWC0"});
|
builder.SetOutputsFormat({"NC1HWC0"});
|
||||||
|
|
|
@ -53,6 +53,8 @@ class TestHWRemoveInternalOutput : public BackendCommon {
|
||||||
KernelBuildInfoBuilder builder;
|
KernelBuildInfoBuilder builder;
|
||||||
builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
|
builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
|
||||||
builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()});
|
builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()});
|
||||||
|
builder.SetInputsReshapeType({{}, {}});
|
||||||
|
builder.SetOutputsReshapeType({{}});
|
||||||
builder.SetOutputsFormat({kOpFormat_NC1HWC0});
|
builder.SetOutputsFormat({kOpFormat_NC1HWC0});
|
||||||
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
||||||
add->set_kernel_info(std::make_shared<device::KernelInfo>());
|
add->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||||
|
@ -78,6 +80,8 @@ class TestHWRemoveInternalOutput : public BackendCommon {
|
||||||
kg->AddInternalOutput(tuple_getitem1, max_pool);
|
kg->AddInternalOutput(tuple_getitem1, max_pool);
|
||||||
kg->AddInternalOutput(tuple_getitem2, max_pool);
|
kg->AddInternalOutput(tuple_getitem2, max_pool);
|
||||||
KernelBuildInfoBuilder builder;
|
KernelBuildInfoBuilder builder;
|
||||||
|
builder.SetInputsReshapeType({{}});
|
||||||
|
builder.SetOutputsReshapeType({{}, {}});
|
||||||
builder.SetInputsFormat({kOpFormat_DEFAULT});
|
builder.SetInputsFormat({kOpFormat_DEFAULT});
|
||||||
builder.SetInputsDeviceType({kFloat32->type_id()});
|
builder.SetInputsDeviceType({kFloat32->type_id()});
|
||||||
builder.SetOutputsFormat({kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
|
builder.SetOutputsFormat({kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
|
||||||
|
@ -99,6 +103,8 @@ class MockRemoveInternalOutputTransOpKernelSelect : public KernelSelect {
|
||||||
builder.SetInputsDeviceType({kFloat16->type_id()});
|
builder.SetInputsDeviceType({kFloat16->type_id()});
|
||||||
builder.SetOutputsFormat({kOpFormat_DEFAULT});
|
builder.SetOutputsFormat({kOpFormat_DEFAULT});
|
||||||
builder.SetOutputsDeviceType({kFloat32->type_id()});
|
builder.SetOutputsDeviceType({kFloat32->type_id()});
|
||||||
|
builder.SetInputsReshapeType({{}});
|
||||||
|
builder.SetOutputsReshapeType({{}});
|
||||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
|
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -51,6 +51,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
|
||||||
builder.SetInputsDeviceType({kFloat16->type_id()});
|
builder.SetInputsDeviceType({kFloat16->type_id()});
|
||||||
builder.SetOutputsFormat({"NC1HWC0"});
|
builder.SetOutputsFormat({"NC1HWC0"});
|
||||||
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
||||||
|
builder.SetInputsReshapeType({});
|
||||||
|
builder.SetOutputsReshapeType({});
|
||||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
|
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
|
||||||
} else {
|
} else {
|
||||||
KernelBuildInfoBuilder builder;
|
KernelBuildInfoBuilder builder;
|
||||||
|
@ -58,7 +60,10 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
|
||||||
builder.SetInputsDeviceType({kFloat16->type_id()});
|
builder.SetInputsDeviceType({kFloat16->type_id()});
|
||||||
builder.SetOutputsFormat({"NC1HWC0"});
|
builder.SetOutputsFormat({"NC1HWC0"});
|
||||||
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
||||||
|
builder.SetInputsReshapeType({});
|
||||||
|
builder.SetOutputsReshapeType({});
|
||||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
|
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -74,6 +79,8 @@ class MockTransdataSplitKernelSelect : public KernelSelect {
|
||||||
builder.SetInputsDeviceType({kFloat16->type_id()});
|
builder.SetInputsDeviceType({kFloat16->type_id()});
|
||||||
builder.SetOutputsFormat({"NCHW"});
|
builder.SetOutputsFormat({"NCHW"});
|
||||||
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
||||||
|
builder.SetInputsReshapeType({{}});
|
||||||
|
builder.SetOutputsReshapeType({{}});
|
||||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
|
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
|
||||||
} else {
|
} else {
|
||||||
KernelBuildInfoBuilder builder;
|
KernelBuildInfoBuilder builder;
|
||||||
|
@ -81,6 +88,8 @@ class MockTransdataSplitKernelSelect : public KernelSelect {
|
||||||
builder.SetInputsDeviceType({kFloat16->type_id()});
|
builder.SetInputsDeviceType({kFloat16->type_id()});
|
||||||
builder.SetOutputsFormat({"NCHW"});
|
builder.SetOutputsFormat({"NCHW"});
|
||||||
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
||||||
|
builder.SetInputsReshapeType({{}});
|
||||||
|
builder.SetOutputsReshapeType({{}});
|
||||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
|
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -116,6 +125,8 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_fraz_nchw) {
|
||||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||||
builder.SetFusionType(kernel::FusionType::ELEMWISE);
|
builder.SetFusionType(kernel::FusionType::ELEMWISE);
|
||||||
builder.SetProcessor(kernel::Processor::AICORE);
|
builder.SetProcessor(kernel::Processor::AICORE);
|
||||||
|
builder.SetInputsReshapeType({{}});
|
||||||
|
builder.SetOutputsReshapeType({{}});
|
||||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||||
kernel_info->set_select_kernel_build_info(builder.Build());
|
kernel_info->set_select_kernel_build_info(builder.Build());
|
||||||
transpose->set_kernel_info(kernel_info);
|
transpose->set_kernel_info(kernel_info);
|
||||||
|
@ -162,6 +173,8 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_nchw_fraz) {
|
||||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||||
builder.SetFusionType(kernel::FusionType::ELEMWISE);
|
builder.SetFusionType(kernel::FusionType::ELEMWISE);
|
||||||
builder.SetProcessor(kernel::Processor::AICORE);
|
builder.SetProcessor(kernel::Processor::AICORE);
|
||||||
|
builder.SetInputsReshapeType({{}});
|
||||||
|
builder.SetOutputsReshapeType({{}});
|
||||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||||
kernel_info->set_select_kernel_build_info(builder.Build());
|
kernel_info->set_select_kernel_build_info(builder.Build());
|
||||||
transpose->set_kernel_info(kernel_info);
|
transpose->set_kernel_info(kernel_info);
|
||||||
|
|
|
@ -58,6 +58,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
|
||||||
builder.SetInputsDeviceType({kFloat16->type_id()});
|
builder.SetInputsDeviceType({kFloat16->type_id()});
|
||||||
builder.SetOutputsFormat({"NC1HWC0"});
|
builder.SetOutputsFormat({"NC1HWC0"});
|
||||||
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
||||||
|
builder.SetInputsReshapeType({});
|
||||||
|
builder.SetOutputsReshapeType({});
|
||||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
|
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
|
||||||
} else {
|
} else {
|
||||||
KernelBuildInfoBuilder builder;
|
KernelBuildInfoBuilder builder;
|
||||||
|
@ -65,6 +67,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
|
||||||
builder.SetInputsDeviceType({kFloat16->type_id()});
|
builder.SetInputsDeviceType({kFloat16->type_id()});
|
||||||
builder.SetOutputsFormat({"NC1HWC0"});
|
builder.SetOutputsFormat({"NC1HWC0"});
|
||||||
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
||||||
|
builder.SetInputsReshapeType({});
|
||||||
|
builder.SetOutputsReshapeType({});
|
||||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
|
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -93,6 +97,8 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) {
|
||||||
EXPECT_NE(transpose, nullptr);
|
EXPECT_NE(transpose, nullptr);
|
||||||
|
|
||||||
KernelBuildInfoBuilder builder;
|
KernelBuildInfoBuilder builder;
|
||||||
|
builder.SetInputsReshapeType({});
|
||||||
|
builder.SetOutputsReshapeType({});
|
||||||
builder.SetInputsFormat({"NCHW"});
|
builder.SetInputsFormat({"NCHW"});
|
||||||
builder.SetInputsDeviceType({kFloat16->type_id()});
|
builder.SetInputsDeviceType({kFloat16->type_id()});
|
||||||
builder.SetOutputsFormat({"NC1HWC0"});
|
builder.SetOutputsFormat({"NC1HWC0"});
|
||||||
|
|
|
@ -56,6 +56,8 @@ class MockEliminate5To4And4To5KernelSelect : public KernelSelect {
|
||||||
~MockEliminate5To4And4To5KernelSelect() override = default;
|
~MockEliminate5To4And4To5KernelSelect() override = default;
|
||||||
void SelectKernel(const CNodePtr &cnode) override {
|
void SelectKernel(const CNodePtr &cnode) override {
|
||||||
KernelBuildInfoBuilder builder;
|
KernelBuildInfoBuilder builder;
|
||||||
|
builder.SetInputsReshapeType({{}});
|
||||||
|
builder.SetOutputsReshapeType({{}});
|
||||||
builder.SetInputsFormat({"NCHW"});
|
builder.SetInputsFormat({"NCHW"});
|
||||||
builder.SetInputsDeviceType({kFloat16->type_id()});
|
builder.SetInputsDeviceType({kFloat16->type_id()});
|
||||||
builder.SetOutputsFormat({"NC1HWC0"});
|
builder.SetOutputsFormat({"NC1HWC0"});
|
||||||
|
@ -102,7 +104,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_5to4_4to5) {
|
||||||
builder.SetOutputsFormat({"NC1HWC0"});
|
builder.SetOutputsFormat({"NC1HWC0"});
|
||||||
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
|
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
|
||||||
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
||||||
|
builder.SetInputsReshapeType({{}, {}});
|
||||||
|
builder.SetOutputsReshapeType({{}});
|
||||||
sub->set_kernel_info(std::make_shared<device::KernelInfo>());
|
sub->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||||
add->set_kernel_info(std::make_shared<device::KernelInfo>());
|
add->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get());
|
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get());
|
||||||
|
@ -168,7 +171,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_cast) {
|
||||||
builder.SetOutputsFormat({"NC1HWC0"});
|
builder.SetOutputsFormat({"NC1HWC0"});
|
||||||
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
|
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
|
||||||
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
||||||
|
builder.SetInputsReshapeType({{}, {}});
|
||||||
|
builder.SetOutputsReshapeType({{}, {}});
|
||||||
sub->set_kernel_info(std::make_shared<device::KernelInfo>());
|
sub->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||||
add->set_kernel_info(std::make_shared<device::KernelInfo>());
|
add->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get());
|
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get());
|
||||||
|
@ -244,7 +248,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_cast_depend_cast) {
|
||||||
builder.SetOutputsFormat({"NC1HWC0"});
|
builder.SetOutputsFormat({"NC1HWC0"});
|
||||||
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
|
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
|
||||||
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
||||||
|
builder.SetInputsReshapeType({{}, {}});
|
||||||
|
builder.SetOutputsReshapeType({{}});
|
||||||
sub->set_kernel_info(std::make_shared<device::KernelInfo>());
|
sub->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||||
add->set_kernel_info(std::make_shared<device::KernelInfo>());
|
add->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get());
|
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get());
|
||||||
|
|
Loading…
Reference in New Issue