forked from mindspore-Ecosystem/mindspore
!3543 Split unsupport transdata
Merge pull request !3543 from lianliguang/unify-primitive
This commit is contained in:
commit
7cb567ebbe
|
@ -158,13 +158,13 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetProcessor(Processor processor)
|
|||
|
||||
std::shared_ptr<KernelBuildInfo> KernelBuildInfo::KernelBuildInfoBuilder::Build() { return kernel_build_info_; }
|
||||
|
||||
void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType(
|
||||
void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsReshapeType(
|
||||
const std::vector<std::vector<Axis>> &input_reshape_type) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_build_info_);
|
||||
kernel_build_info_->input_reshape_type_ = input_reshape_type;
|
||||
}
|
||||
|
||||
void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType(
|
||||
void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsReshapeType(
|
||||
const std::vector<std::vector<Axis>> &output_reshape_type) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_build_info_);
|
||||
kernel_build_info_->output_reshape_type_ = output_reshape_type;
|
||||
|
@ -189,5 +189,36 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputFormat(const std::string
|
|||
}
|
||||
kernel_build_info_->outputs_format_[index] = format;
|
||||
}
|
||||
void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType(const std::vector<Axis> &input_reshape_type,
|
||||
size_t index) {
|
||||
if (index >= kernel_build_info_->input_reshape_type_.size()) {
|
||||
MS_LOG(EXCEPTION) << "index outof range!";
|
||||
}
|
||||
std::copy(input_reshape_type.begin(), input_reshape_type.end(),
|
||||
std::back_inserter(kernel_build_info_->input_reshape_type_[index]));
|
||||
}
|
||||
|
||||
void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType(const std::vector<Axis> &output_reshape_type,
|
||||
size_t index) {
|
||||
if (index >= kernel_build_info_->output_reshape_type_.size()) {
|
||||
MS_LOG(EXCEPTION) << "index outof range!";
|
||||
}
|
||||
std::copy(output_reshape_type.begin(), output_reshape_type.end(),
|
||||
std::back_inserter(kernel_build_info_->output_reshape_type_[index]));
|
||||
}
|
||||
|
||||
void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputDeviceType(const TypeId &output_device_type, size_t index) {
|
||||
if (index >= kernel_build_info_->outputs_device_type_.size()) {
|
||||
MS_LOG(EXCEPTION) << "index outof range!";
|
||||
}
|
||||
kernel_build_info_->outputs_device_type_[index] = output_device_type;
|
||||
}
|
||||
|
||||
void KernelBuildInfo::KernelBuildInfoBuilder::SetInputDeviceType(const TypeId &input_device_type, size_t index) {
|
||||
if (index >= kernel_build_info_->inputs_device_type_.size()) {
|
||||
MS_LOG(EXCEPTION) << "index outof range!";
|
||||
}
|
||||
kernel_build_info_->inputs_device_type_[index] = input_device_type;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -71,6 +71,10 @@ class KernelBuildInfo {
|
|||
|
||||
std::vector<TypeId> GetAllOutputDeviceTypes() const;
|
||||
|
||||
std::vector<std::vector<Axis>> GetAllOutputReshapeType() const;
|
||||
|
||||
std::vector<std::vector<Axis>> GetAllInputReshapeType() const;
|
||||
|
||||
OpPattern op_pattern() const { return op_pattern_; }
|
||||
|
||||
FusionType fusion_type() const { return fusion_type_; }
|
||||
|
@ -109,7 +113,22 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
|
|||
KernelBuildInfoBuilder() { kernel_build_info_ = std::make_shared<KernelBuildInfo>(); }
|
||||
|
||||
explicit KernelBuildInfoBuilder(std::shared_ptr<KernelBuildInfo> kernel_build_info)
|
||||
: kernel_build_info_(std::move(kernel_build_info)) {}
|
||||
: kernel_build_info_(std::make_shared<KernelBuildInfo>()) {
|
||||
SetKernelType(kernel_build_info->kernel_type());
|
||||
SetFusionType(kernel_build_info->fusion_type());
|
||||
SetProcessor(kernel_build_info->processor());
|
||||
OpPattern(kernel_build_info->op_pattern());
|
||||
for (size_t index = 0; index < kernel_build_info->GetInputNum(); ++index) {
|
||||
kernel_build_info_->inputs_device_type_.emplace_back(kernel_build_info->GetInputDeviceType(index));
|
||||
kernel_build_info_->inputs_format_.emplace_back(kernel_build_info->GetInputFormat(index));
|
||||
kernel_build_info_->input_reshape_type_.emplace_back(kernel_build_info->GetInputReshapeType(index));
|
||||
}
|
||||
for (size_t index = 0; index < kernel_build_info->GetOutputNum(); ++index) {
|
||||
kernel_build_info_->outputs_device_type_.emplace_back(kernel_build_info->GetOutputDeviceType(index));
|
||||
kernel_build_info_->outputs_format_.emplace_back(kernel_build_info->GetOutputFormat(index));
|
||||
kernel_build_info_->output_reshape_type_.emplace_back(kernel_build_info->GetOutputReshapeType(index));
|
||||
}
|
||||
}
|
||||
|
||||
~KernelBuildInfoBuilder() = default;
|
||||
|
||||
|
@ -123,9 +142,9 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
|
|||
|
||||
void SetOutputsDeviceType(const std::vector<TypeId> &outputs_device_type);
|
||||
|
||||
void SetInputReshapeType(const std::vector<std::vector<Axis>> &input_reshape_type);
|
||||
void SetInputsReshapeType(const std::vector<std::vector<Axis>> &input_reshape_type);
|
||||
|
||||
void SetOutputReshapeType(const std::vector<std::vector<Axis>> &output_reshape_type);
|
||||
void SetOutputsReshapeType(const std::vector<std::vector<Axis>> &output_reshape_type);
|
||||
|
||||
void SetFusionType(FusionType fusion_type);
|
||||
|
||||
|
@ -137,6 +156,14 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
|
|||
|
||||
void SetOutputFormat(const std::string &format, size_t index);
|
||||
|
||||
void SetInputReshapeType(const std::vector<Axis> &input_reshape_type, size_t index);
|
||||
|
||||
void SetOutputReshapeType(const std::vector<Axis> &output_reshape_type, size_t index);
|
||||
|
||||
void SetInputDeviceType(const TypeId &input_device_type, size_t index);
|
||||
|
||||
void SetOutputDeviceType(const TypeId &output_device_type, size_t index);
|
||||
|
||||
std::shared_ptr<KernelBuildInfo> Build();
|
||||
|
||||
private:
|
||||
|
|
|
@ -118,7 +118,7 @@ void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
|
|||
}
|
||||
builder.SetInputsDeviceType(inputs_device_type);
|
||||
builder.SetInputsFormat(inputs_format);
|
||||
builder.SetInputReshapeType(inputs_reshape_type);
|
||||
builder.SetInputsReshapeType(inputs_reshape_type);
|
||||
// output
|
||||
std::vector<std::string> outputs_format;
|
||||
std::vector<TypeId> outputs_device_type;
|
||||
|
@ -129,7 +129,7 @@ void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
|
|||
}
|
||||
builder.SetOutputsDeviceType(outputs_device_type);
|
||||
builder.SetOutputsFormat(outputs_format);
|
||||
builder.SetOutputReshapeType(outputs_reshape_type);
|
||||
builder.SetOutputsReshapeType(outputs_reshape_type);
|
||||
kernel_info_list_->emplace_back(builder.Build());
|
||||
}
|
||||
MS_LOG(INFO) << "end.";
|
||||
|
|
|
@ -59,6 +59,7 @@
|
|||
#include "backend/optimizer/ascend/format_type/insert_trans_op.h"
|
||||
#include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h"
|
||||
#include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h"
|
||||
#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
|
||||
#include "backend/optimizer/pass/getitem_tuple.h"
|
||||
#include "backend/optimizer/pass/optimize_dependence.h"
|
||||
#include "backend/optimizer/pass/erase_visit_attr.h"
|
||||
|
@ -228,6 +229,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
|
|||
mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<SplitUnsupportedTransData>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<ConvertUnSupportNodeToAICPU>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<RemoveInternalOutputCast>());
|
||||
optimizer->AddPassManager(mixed_precision_pm);
|
||||
|
|
|
@ -174,8 +174,8 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string &
|
|||
MS_EXCEPTION_IF_NULL(ori_build_info);
|
||||
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(ori_build_info);
|
||||
builder->SetInputsFormat({input_format});
|
||||
builder->SetInputReshapeType({reshape_type});
|
||||
builder->SetOutputReshapeType({reshape_type});
|
||||
builder->SetInputsReshapeType({reshape_type});
|
||||
builder->SetOutputsReshapeType({reshape_type});
|
||||
builder->SetOutputsFormat({output_format});
|
||||
if (type_id != kTypeUnknown) {
|
||||
builder->SetOutputsDeviceType({type_id});
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
const BaseRef SplitUnsupportedTransData::DefinePattern() const {
|
||||
VarPtr X = std::make_shared<Var>();
|
||||
return VectorRef({prim::KPrimTransData, X});
|
||||
}
|
||||
|
||||
const AnfNodePtr SplitUnsupportedTransData::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto ori_trans_data = node->cast<CNodePtr>();
|
||||
if (AnfAlgo::GetCNodeName(ori_trans_data) != prim::KPrimTransData->name()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(ori_trans_data);
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
if (kernel_info->GetInputNum() != 1 || kernel_info->GetOutputNum() != 1) {
|
||||
MS_LOG(EXCEPTION) << "Transdata node's kernel info's input and output format size is not 1"
|
||||
<< ori_trans_data->DebugString();
|
||||
}
|
||||
return SplitTransData(func_graph, ori_trans_data);
|
||||
}
|
||||
AnfNodePtr SplitUnsupportedTransData::SplitTransData(const FuncGraphPtr &func_graph, const CNodePtr &trans_node) const {
|
||||
auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(trans_node);
|
||||
if (kHWSpecialFormatSet.find(kernel_info->GetInputFormat(0)) == kHWSpecialFormatSet.end() ||
|
||||
kHWSpecialFormatSet.find(kernel_info->GetOutputFormat(0)) == kHWSpecialFormatSet.end()) {
|
||||
return trans_node;
|
||||
}
|
||||
auto builder_info_to_default = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info);
|
||||
auto builder_info_to_special_foramt = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info);
|
||||
builder_info_to_default->SetOutputsFormat({kOpFormat_DEFAULT});
|
||||
builder_info_to_special_foramt->SetInputsFormat({kOpFormat_DEFAULT});
|
||||
std::vector<AnfNodePtr> next_trans_node_inputs = {
|
||||
NewValueNode(std::make_shared<Primitive>(prim::KPrimTransData->name())), trans_node};
|
||||
auto next_trans_node = func_graph->NewCNode(next_trans_node_inputs);
|
||||
next_trans_node->set_abstract(trans_node->abstract());
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_default->Build(), trans_node.get());
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_special_foramt->Build(), next_trans_node.get());
|
||||
return next_trans_node;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H
|
||||
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class SplitUnsupportedTransData : public PatternProcessPass {
|
||||
public:
|
||||
explicit SplitUnsupportedTransData(bool multigraph = true)
|
||||
: PatternProcessPass("split_unsupported_transdata", multigraph) {}
|
||||
~SplitUnsupportedTransData() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
AnfNodePtr SplitTransData(const FuncGraphPtr &func_graph, const CNodePtr &trans_node) const;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H
|
|
@ -50,6 +50,8 @@ class TestHWInsertTransOp : public BackendCommon {
|
|||
KernelBuildInfoBuilder builder;
|
||||
builder.SetInputsFormat({format, format});
|
||||
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
|
||||
builder.SetInputsReshapeType({{},{}});
|
||||
builder.SetOutputsReshapeType({});
|
||||
builder.SetOutputsFormat({format});
|
||||
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
||||
add->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
|
@ -70,6 +72,8 @@ class TestHWInsertTransOp : public BackendCommon {
|
|||
EXPECT_NE(ret->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>()->input(1), nullptr);
|
||||
auto max_pool = ret->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>()->input(1);
|
||||
KernelBuildInfoBuilder builder;
|
||||
builder.SetInputsReshapeType({{}});
|
||||
builder.SetOutputsReshapeType({{},{}});
|
||||
builder.SetInputsFormat({kOpFormat_DEFAULT});
|
||||
builder.SetInputsDeviceType({kFloat16->type_id()});
|
||||
builder.SetOutputsFormat({format, format});
|
||||
|
@ -88,6 +92,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
|
|||
~MockInsertTransOpKernelSelectTrans4Dto5D() override = default;
|
||||
void SelectKernel(const CNodePtr &cnode) override {
|
||||
KernelBuildInfoBuilder builder;
|
||||
builder.SetInputsReshapeType({{}});
|
||||
builder.SetOutputsReshapeType({{}});
|
||||
builder.SetInputsFormat({"NCHW"});
|
||||
builder.SetInputsDeviceType({kFloat16->type_id()});
|
||||
builder.SetOutputsFormat({"NC1HWC0"});
|
||||
|
|
|
@ -53,6 +53,8 @@ class TestHWRemoveInternalOutput : public BackendCommon {
|
|||
KernelBuildInfoBuilder builder;
|
||||
builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
|
||||
builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()});
|
||||
builder.SetInputsReshapeType({{}, {}});
|
||||
builder.SetOutputsReshapeType({{}});
|
||||
builder.SetOutputsFormat({kOpFormat_NC1HWC0});
|
||||
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
||||
add->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
|
@ -78,6 +80,8 @@ class TestHWRemoveInternalOutput : public BackendCommon {
|
|||
kg->AddInternalOutput(tuple_getitem1, max_pool);
|
||||
kg->AddInternalOutput(tuple_getitem2, max_pool);
|
||||
KernelBuildInfoBuilder builder;
|
||||
builder.SetInputsReshapeType({{}});
|
||||
builder.SetOutputsReshapeType({{}, {}});
|
||||
builder.SetInputsFormat({kOpFormat_DEFAULT});
|
||||
builder.SetInputsDeviceType({kFloat32->type_id()});
|
||||
builder.SetOutputsFormat({kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
|
||||
|
@ -99,6 +103,8 @@ class MockRemoveInternalOutputTransOpKernelSelect : public KernelSelect {
|
|||
builder.SetInputsDeviceType({kFloat16->type_id()});
|
||||
builder.SetOutputsFormat({kOpFormat_DEFAULT});
|
||||
builder.SetOutputsDeviceType({kFloat32->type_id()});
|
||||
builder.SetInputsReshapeType({{}});
|
||||
builder.SetOutputsReshapeType({{}});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
|
||||
}
|
||||
};
|
||||
|
|
|
@ -51,6 +51,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
|
|||
builder.SetInputsDeviceType({kFloat16->type_id()});
|
||||
builder.SetOutputsFormat({"NC1HWC0"});
|
||||
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
||||
builder.SetInputsReshapeType({});
|
||||
builder.SetOutputsReshapeType({});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
|
||||
} else {
|
||||
KernelBuildInfoBuilder builder;
|
||||
|
@ -58,7 +60,10 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
|
|||
builder.SetInputsDeviceType({kFloat16->type_id()});
|
||||
builder.SetOutputsFormat({"NC1HWC0"});
|
||||
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
||||
builder.SetInputsReshapeType({});
|
||||
builder.SetOutputsReshapeType({});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
|
||||
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -74,6 +79,8 @@ class MockTransdataSplitKernelSelect : public KernelSelect {
|
|||
builder.SetInputsDeviceType({kFloat16->type_id()});
|
||||
builder.SetOutputsFormat({"NCHW"});
|
||||
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
||||
builder.SetInputsReshapeType({{}});
|
||||
builder.SetOutputsReshapeType({{}});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
|
||||
} else {
|
||||
KernelBuildInfoBuilder builder;
|
||||
|
@ -81,6 +88,8 @@ class MockTransdataSplitKernelSelect : public KernelSelect {
|
|||
builder.SetInputsDeviceType({kFloat16->type_id()});
|
||||
builder.SetOutputsFormat({"NCHW"});
|
||||
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
||||
builder.SetInputsReshapeType({{}});
|
||||
builder.SetOutputsReshapeType({{}});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
|
||||
}
|
||||
}
|
||||
|
@ -116,6 +125,8 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_fraz_nchw) {
|
|||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
builder.SetFusionType(kernel::FusionType::ELEMWISE);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
builder.SetInputsReshapeType({{}});
|
||||
builder.SetOutputsReshapeType({{}});
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
kernel_info->set_select_kernel_build_info(builder.Build());
|
||||
transpose->set_kernel_info(kernel_info);
|
||||
|
@ -162,6 +173,8 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_nchw_fraz) {
|
|||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
builder.SetFusionType(kernel::FusionType::ELEMWISE);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
builder.SetInputsReshapeType({{}});
|
||||
builder.SetOutputsReshapeType({{}});
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
kernel_info->set_select_kernel_build_info(builder.Build());
|
||||
transpose->set_kernel_info(kernel_info);
|
||||
|
|
|
@ -58,6 +58,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
|
|||
builder.SetInputsDeviceType({kFloat16->type_id()});
|
||||
builder.SetOutputsFormat({"NC1HWC0"});
|
||||
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
||||
builder.SetInputsReshapeType({});
|
||||
builder.SetOutputsReshapeType({});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
|
||||
} else {
|
||||
KernelBuildInfoBuilder builder;
|
||||
|
@ -65,6 +67,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
|
|||
builder.SetInputsDeviceType({kFloat16->type_id()});
|
||||
builder.SetOutputsFormat({"NC1HWC0"});
|
||||
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
||||
builder.SetInputsReshapeType({});
|
||||
builder.SetOutputsReshapeType({});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
|
||||
}
|
||||
}
|
||||
|
@ -93,6 +97,8 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) {
|
|||
EXPECT_NE(transpose, nullptr);
|
||||
|
||||
KernelBuildInfoBuilder builder;
|
||||
builder.SetInputsReshapeType({});
|
||||
builder.SetOutputsReshapeType({});
|
||||
builder.SetInputsFormat({"NCHW"});
|
||||
builder.SetInputsDeviceType({kFloat16->type_id()});
|
||||
builder.SetOutputsFormat({"NC1HWC0"});
|
||||
|
|
|
@ -56,6 +56,8 @@ class MockEliminate5To4And4To5KernelSelect : public KernelSelect {
|
|||
~MockEliminate5To4And4To5KernelSelect() override = default;
|
||||
void SelectKernel(const CNodePtr &cnode) override {
|
||||
KernelBuildInfoBuilder builder;
|
||||
builder.SetInputsReshapeType({{}});
|
||||
builder.SetOutputsReshapeType({{}});
|
||||
builder.SetInputsFormat({"NCHW"});
|
||||
builder.SetInputsDeviceType({kFloat16->type_id()});
|
||||
builder.SetOutputsFormat({"NC1HWC0"});
|
||||
|
@ -102,7 +104,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_5to4_4to5) {
|
|||
builder.SetOutputsFormat({"NC1HWC0"});
|
||||
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
|
||||
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
||||
|
||||
builder.SetInputsReshapeType({{}, {}});
|
||||
builder.SetOutputsReshapeType({{}});
|
||||
sub->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
add->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get());
|
||||
|
@ -168,7 +171,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_cast) {
|
|||
builder.SetOutputsFormat({"NC1HWC0"});
|
||||
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
|
||||
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
||||
|
||||
builder.SetInputsReshapeType({{}, {}});
|
||||
builder.SetOutputsReshapeType({{}, {}});
|
||||
sub->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
add->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get());
|
||||
|
@ -244,7 +248,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_cast_depend_cast) {
|
|||
builder.SetOutputsFormat({"NC1HWC0"});
|
||||
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
|
||||
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
||||
|
||||
builder.SetInputsReshapeType({{}, {}});
|
||||
builder.SetOutputsReshapeType({{}});
|
||||
sub->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
add->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get());
|
||||
|
|
Loading…
Reference in New Issue