forked from mindspore-Ecosystem/mindspore
!5883 support for frac_zn_lstm
Merge pull request !5883 from liubuyu/master
This commit is contained in: commit 7b3873559f
@@ -60,6 +60,7 @@
 #include "backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h"
 #include "backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h"
 #include "backend/optimizer/ascend/format_type/insert_trans_op.h"
+#include "backend/optimizer/ascend/format_type/insert_transpose_for_basiclstm_op.h"
 #include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h"
 #include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h"
 #include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
@@ -286,6 +287,9 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   }
   ir_fusion_pm->AddPass(std::make_shared<InsertMemcpyAsyncForHcclOp>());
   ir_fusion_pm->AddPass(std::make_shared<AddInputToOutput>());
+  ir_fusion_pm->AddPass(std::make_shared<InsertTranspose>());
+  ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
+  ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>());
   optimizer->AddPassManager(ir_fusion_pm);
   (void)optimizer->Optimize(kernel_graph);
   kernel_graph->SetExecOrderByDefault();

@@ -142,6 +142,15 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
   return node;
 }
+
+void ReFreshInferShape(const AnfNodePtr &node, const std::string &op_name) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (op_name == kBasicLSTMCellWeightGradOpName && AnfAlgo::GetCNodeName(node) == prim::kPrimReshape->name()) {
+    auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
+    auto type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
+    AnfAlgo::SetOutputInferTypeAndShape({type}, {{shape[0], shape[1]}}, node.get());
+  }
+}
 
 AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                           const KernelSelectPtr &kernel_select) {
   MS_EXCEPTION_IF_NULL(func_graph);
@@ -149,6 +158,10 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
   std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
   auto kernel_graph = func_graph->cast<KernelGraphPtr>();
   size_t out_num = AnfAlgo::GetOutputTensorNum(node);
+  std::string op_name;
+  if (node->isa<CNode>()) {
+    op_name = AnfAlgo::GetCNodeName(node);
+  }
   for (size_t output_idx = 0; output_idx < out_num; ++output_idx) {
     std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx);
     if (output_format == kOpFormat_NC1KHKWHWC0) {
@@ -159,6 +172,7 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
     std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
     if (origin_shape.size() > 1 && kCommonFormatSet.find(output_format) == kCommonFormatSet.end()) {
       auto trans_op = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false);
+      ReFreshInferShape(trans_op, op_name);
       if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node, output_idx)) {
         kernel_graph->ReplaceInternalOutput(node, trans_op, output_idx, 0);
       }

@@ -0,0 +1,100 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/optimizer/ascend/format_type/insert_transpose_for_basiclstm_op.h"
+#include <memory>
+#include <vector>
+#include "utils/utils.h"
+#include "backend/optimizer/ascend/ascend_helper.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "runtime/device/kernel_info.h"
+#include "backend/kernel_compiler/oplib/oplib.h"
+#include "utils/ms_context.h"
+
+namespace mindspore {
+namespace opt {
+const BaseRef InsertTranspose::DefinePattern() const {
+  std::shared_ptr<Var> V = std::make_shared<CondVar>(UnVisited);
+  std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
+  return VectorRef({V, Xs});
+}
+
+CNodePtr Insert(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::string &op_name) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
+  CNodePtr new_node = nullptr;
+
+  std::vector<AnfNodePtr> transpose_inputs;
+  auto prim = std::make_shared<Primitive>(prim::kPrimTranspose->name());
+  transpose_inputs.push_back(NewValueNode(prim));
+
+  if (op_name == kBasicLSTMCellInputGradOpName) {
+    auto origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 1);
+    auto origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 1);
+    auto dst_shape = {origin_shape[1], origin_shape[0]};
+    transpose_inputs.push_back(AnfAlgo::GetInputNode(cnode, 1));
+    CNodePtr transpose = func_graph->NewCNode(transpose_inputs);
+    MS_EXCEPTION_IF_NULL(transpose);
+    AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {dst_shape}, transpose.get());
+    AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{1, 0}), transpose);
+    AnfAlgo::SetNodeInput(cnode, transpose, 1);
+    if (kernel_graph == nullptr) {
+      new_node = std::make_shared<CNode>(*cnode);
+    } else {
+      new_node = kernel_graph->NewCNode(cnode);
+    }
+
+  } else if (op_name == kBasicLSTMCellWeightGradOpName) {
+    std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
+    size_t out_num = AnfAlgo::GetOutputTensorNum(cnode);
+    for (size_t output_idx = 0; output_idx < out_num; output_idx++) {
+      auto tuple_getitem = CreatTupleGetItemNode(func_graph, cnode, output_idx);
+      auto origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx);
+      if (origin_shape.size() > 1 && output_idx == 0) {
+        auto dtype = AnfAlgo::GetOutputInferDataType(cnode, output_idx);
+        auto dst_shape = {origin_shape[0], origin_shape[1]};
+        transpose_inputs.push_back(tuple_getitem);
+        CNodePtr transpose = func_graph->NewCNode(transpose_inputs);
+        MS_EXCEPTION_IF_NULL(transpose);
+        AnfAlgo::SetOutputInferTypeAndShape({dtype}, {dst_shape}, transpose.get());
+        AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{1, 0}), transpose);
+        make_tuple_inputs.push_back(transpose);
+      } else {
+        make_tuple_inputs.push_back(tuple_getitem);
+      }
+    }
+    new_node = func_graph->NewCNode(make_tuple_inputs);
+  }
+  return new_node;
+}
+
+const AnfNodePtr InsertTranspose::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                          const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto op_name = AnfAlgo::GetCNodeName(cnode);
+  CNodePtr new_node = nullptr;
+  if (op_name == kBasicLSTMCellInputGradOpName || op_name == kBasicLSTMCellWeightGradOpName) {
+    new_node = Insert(func_graph, cnode, op_name);
+  }
+  return new_node;
+}
+}  // namespace opt
+}  // namespace mindspore

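For intuition about what this new pass does: the Transpose node it inserts carries perm {1, 0}, i.e. a plain two-axis swap of the LSTM weight, so a (C, N) weight becomes (N, C), matching the {origin_shape[1], origin_shape[0]} dst_shape above. A minimal standalone sketch in plain C++, independent of MindSpore's graph API (Transpose2D is a hypothetical helper for illustration, not part of the commit):

    // Standalone illustration of the perm {1, 0} transpose the pass inserts.
    // Swaps the two axes of a row-major rows x cols matrix.
    #include <cstddef>
    #include <iostream>
    #include <vector>

    std::vector<float> Transpose2D(const std::vector<float> &src, size_t rows, size_t cols) {
      std::vector<float> dst(src.size());
      for (size_t r = 0; r < rows; ++r) {
        for (size_t c = 0; c < cols; ++c) {
          dst[c * rows + r] = src[r * cols + c];  // dst has shape (cols, rows)
        }
      }
      return dst;
    }

    int main() {
      // A 2x3 matrix becomes 3x2, mirroring the dst_shape swap above.
      std::vector<float> w = {1, 2, 3, 4, 5, 6};
      for (float v : Transpose2D(w, 2, 3)) std::cout << v << " ";  // prints: 1 4 2 5 3 6
      return 0;
    }
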
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANSPOSE_FOR_BASICLSTM_OP_H
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANSPOSE_FOR_BASICLSTM_OP_H
+#include <string>
+#include <utility>
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/common/helper.h"
+#include "backend/optimizer/ascend/ascend_helper.h"
+
+namespace mindspore {
+namespace opt {
+class InsertTranspose : public PatternProcessPass {
+ public:
+  explicit InsertTranspose(bool multigraph = true)
+      : PatternProcessPass("insert_transpose_for_basiclstm_op", multigraph) {}
+  ~InsertTranspose() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_TRANSPOSE_FOR_BASICLSTM_OP_H

@@ -24,6 +24,7 @@ namespace opt {
 const std::set<std::pair<string, string>> invalid_formats_pair = {{kOpFormat_C1HWNCoC0, kOpFormat_NCHW},
                                                                   {kOpFormat_NCHW, kOpFormat_C1HWNCoC0},
                                                                   {kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT},
+                                                                  {kOpFormat_DEFAULT, kOpFormat_FRACTAL_ZN_LSTM},
                                                                   {kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}};
 
 bool TransDataSplit::Run(const FuncGraphPtr &func_graph) {
@@ -64,12 +65,13 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
   AnfNodePtr new_transdata_node = nullptr;
   AnfNodePtr new_transpose_node = nullptr;
   AnfNodePtr new_replace_node = nullptr;
+  auto padding_axis = AnfAlgo::GetOutputReshapeType(node, 0);
   // if output_format is default, split TransData into transdata -> transpose; otherwise transpose -> transdata
   if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) {
     // trans input_format to hwcn
     new_transdata_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_,
                                         false, prim::KPrimTransData->name());
-    RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transdata_node);
+    RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transdata_node, padding_axis);
     // trans hwcn to default_format
     new_transpose_node =
         NewTransOpNode(func_graph, new_transdata_node, kernel_select_, false, prim::kPrimTranspose->name());
@@ -86,7 +88,7 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
     // trans hwcn to output_format
     new_transdata_node =
         NewTransOpNode(func_graph, new_transpose_node, kernel_select_, false, prim::KPrimTransData->name());
-    RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transdata_node);
+    RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transdata_node, padding_axis);
     new_transdata_node->set_abstract(node->abstract());
     new_replace_node = new_transdata_node;
   }

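The two hunks above implement the split named in the comment: a TransData between the default format and a special format like FRACTAL_ZN_LSTM is not done in one step but routed through HWCN, and the order of the inserted pair depends on which side is default. A hedged sketch of just that ordering decision (SplitOrder is a hypothetical illustration, not the actual DoSplit code):

    #include <iostream>
    #include <string>

    // Assumed simplification: print which node pair replaces an unsupported TransData.
    void SplitOrder(const std::string &input_format, const std::string &output_format) {
      if (output_format == "DefaultFormat" || output_format == "NCHW") {
        std::cout << "TransData(" << input_format << " -> HWCN), then Transpose(HWCN -> " << output_format << ")\n";
      } else {
        std::cout << "Transpose(" << input_format << " -> HWCN), then TransData(HWCN -> " << output_format << ")\n";
      }
    }

    int main() {
      SplitOrder("FRACTAL_ZN_LSTM", "DefaultFormat");
      SplitOrder("DefaultFormat", "FRACTAL_ZN_LSTM");
      return 0;
    }
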
@@ -56,7 +56,7 @@ inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx,
 template <typename T>
 T DivCeil(T n1, T n2) {
   if (n2 != 0) {
-    return (n1 - 1) / n2 + 1;
+    return (n1 + n2 - 1) / n2;
   }
   return 0;
 }
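The DivCeil change matters for unsigned operands: with n1 == 0 the old expression (n1 - 1) / n2 + 1 wraps around to a huge value, while (n1 + n2 - 1) / n2 correctly yields 0. A small standalone demonstration (an assumed example, not code from the commit):

    #include <cstddef>
    #include <iostream>

    int main() {
      size_t n1 = 0, n2 = 16;
      std::cout << (n1 - 1) / n2 + 1 << "\n";   // wraps: SIZE_MAX / 16 + 1, a huge value
      std::cout << (n1 + n2 - 1) / n2 << "\n";  // 0, the expected ceiling of 0 / 16
      return 0;
    }
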
@@ -444,6 +444,17 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
     device_shape.push_back(kCubeSize);
     device_shape.push_back(kCubeSize);
     return device_shape;
+  } else if (format == kOpFormat_FRACTAL_ZN_LSTM) {
+    const size_t c0 = 4;
+    const size_t h = shape.at(kN) / c0;
+    const size_t i = shape.at(kC) - h;
+    const size_t first = DivCeil(i, kCubeSize) + DivCeil(h, kCubeSize);
+    const size_t second = c0 * DivCeil(h, kCubeSize);
+    device_shape.push_back(first);
+    device_shape.push_back(second);
+    device_shape.push_back(kCubeSize);
+    device_shape.push_back(kCubeSize);
+    return device_shape;
   }
   if (shape.size() != kNchwDims) {
     MS_LOG(WARNING) << "Get device shape with a shape size less than 4; the shape should be padded to the default format first";

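To make the FRACTAL_ZN_LSTM branch concrete, here is a worked example with assumed sizes (hidden = 64, input = 32, and kN = 0, kC = 1 as in NCHW indexing): shape[kN] = 4 * hidden = 256 and shape[kC] = input + hidden = 96, so h = 64, i = 32, first = DivCeil(32, 16) + DivCeil(64, 16) = 6, second = 4 * DivCeil(64, 16) = 16, giving a device shape of {6, 16, 16, 16}. A hypothetical standalone check of that arithmetic (the sizes are assumptions for illustration, not values from the commit):

    #include <cstddef>
    #include <iostream>

    size_t DivCeil(size_t n1, size_t n2) { return n2 != 0 ? (n1 + n2 - 1) / n2 : 0; }

    int main() {
      const size_t kCubeSize = 16;
      const size_t c0 = 4;
      const size_t n = 256;     // shape.at(kN) = 4 * hidden
      const size_t c = 96;      // shape.at(kC) = input + hidden
      const size_t h = n / c0;  // 64
      const size_t i = c - h;   // 32
      std::cout << DivCeil(i, kCubeSize) + DivCeil(h, kCubeSize) << " "  // 6
                << c0 * DivCeil(h, kCubeSize) << " "                     // 16
                << kCubeSize << " " << kCubeSize << "\n";                // 16 16
      return 0;
    }
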
@@ -216,6 +216,9 @@ constexpr auto kReduceMinOpName = "ReduceMin";
 constexpr auto kReduceMaxOpName = "ReduceMax";
 constexpr auto kFusedWeightScaleApplyMomentum = "FusedWeightScaleApplyMomentum";
 constexpr auto kFusedScaleApplyMomentum = "FusedScaleApplyMomentum";
+constexpr auto kBasicLSTMCellWeightGradOpName = "BasicLSTMCellWeightGrad";
+constexpr auto kBasicLSTMCellInputGradOpName = "BasicLSTMCellInputGrad";
+constexpr auto kBasicLSTMCellOpName = "BasicLSTMCell";
 
 // attr key name
 constexpr auto kAttrInputNames = "input_names";
@@ -344,10 +347,11 @@ constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0";
 constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04";
 constexpr auto kOpFormat_FRACTAL_Z_C04 = "FRACTAL_Z_C04";
 constexpr auto kOpFormat_NDHWC = "NDHWC";
+constexpr auto kOpFormat_FRACTAL_ZN_LSTM = "FRACTAL_ZN_LSTM";
 const std::set<std::string> kOpFormatList = {
-  kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC,
-  kOpFormat_HWCN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ,
-  kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDHWC};
+  kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC,
+  kOpFormat_HWCN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ,
+  kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDHWC, kOpFormat_FRACTAL_ZN_LSTM};
 const std::set<std::string> kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN};
 const std::set<std::string> kOptOperatorSet = {
   kMomentumOpName,
@@ -373,9 +377,9 @@ const std::set<std::string> kOptOperatorSet = {
   kPullOpName,
 };
 
-const std::set<std::string> kHWSpecialFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
-                                                   kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04,
-                                                   kOpFormat_FRACTAL_Z_C04};
+const std::set<std::string> kHWSpecialFormatSet = {
+  kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ,
+  kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_FRACTAL_ZN_LSTM};
 
 const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32};

@@ -30,7 +30,7 @@ basic_lstm_cell_op_info = TBERegOp("BasicLSTMCell") \
     .input(0, "x", False, "required", "all") \
     .input(1, "h", False, "required", "all") \
     .input(2, "c", False, "required", "all") \
-    .input(3, "w", False, "required", "all") \
+    .input(3, "w", False, "required", "all", reshape_type="CN") \
     .input(4, "b", False, "required", "all") \
     .input(5, "mask", False, "optional", "all") \
     .output(0, "ct", False, "required", "all") \
@@ -40,11 +40,11 @@ basic_lstm_cell_op_info = TBERegOp("BasicLSTMCell") \
     .output(4, "ft", False, "optional", "all") \
     .output(5, "ot", False, "optional", "all") \
     .output(6, "tanhct", False, "optional", "all") \
-    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracZ,
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracZNLSTM,
                   DataType.F32_Default, DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ,
                   DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
                   DataType.F32_FracNZ) \
-    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracZ,
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracZNLSTM,
                   DataType.F16_Default, DataType.U8_Default, DataType.F16_FracNZ, DataType.F16_FracNZ,
                   DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
                   DataType.F16_FracNZ) \

@@ -25,7 +25,7 @@ basic_lstm_cell_input_grad_op_info = TBERegOp("BasicLSTMCellInputGrad") \
    .attr("keep_prob", "optional", "float", "all") \
    .partial_flag(True) \
    .input(0, "dgate", False, "required", "all") \
-   .input(1, "w", False, "required", "all") \
+   .input(1, "w", False, "required", "all", reshape_type="NC") \
    .input(2, "dropout_mask", False, "optional", "all") \
    .output(0, "dxt", False, "required", "all") \
    .output(1, "dht", False, "required", "all") \

@@ -26,7 +26,7 @@ basic_lstm_cell_weight_grad_op_info = TBERegOp("BasicLSTMCellWeightGrad") \
    .input(0, "x", False, "required", "all") \
    .input(1, "h", False, "required", "all") \
    .input(2, "dgate", False, "required", "all") \
-   .output(0, "dw", False, "required", "all") \
+   .output(0, "dw", False, "required", "all", reshape_type="CN") \
    .output(1, "db", False, "required", "all") \
    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracZ,
                  DataType.F32_Default) \

@@ -129,6 +129,10 @@ trans_data_op_info = TBERegOp("TransData") \
     .dtype_format(DataType.F16_FracZ, DataType.F16_HWCN) \
     .dtype_format(DataType.F16_HWCN, DataType.F16_FracNZ) \
     .dtype_format(DataType.F32_HWCN, DataType.F16_FracNZ) \
+    .dtype_format(DataType.F16_HWCN, DataType.F16_FracZNLSTM) \
+    .dtype_format(DataType.F32_HWCN, DataType.F32_FracZNLSTM) \
+    .dtype_format(DataType.F16_FracZNLSTM, DataType.F16_HWCN) \
+    .dtype_format(DataType.F32_FracZNLSTM, DataType.F32_HWCN) \
     .get_op_info()
 
 

@@ -619,6 +619,7 @@ class DataType:
     F16_NHWC = ("float16", "NHWC")
     F16_HWCN = ("float16", "HWCN")
     F16_NDHWC = ("float16", "NDHWC")
+    F16_FracZNLSTM = ("float16", "FRACTAL_ZN_LSTM")
 
     F32_None = ("float32", "")
     F32_Default = ("float32", "DefaultFormat")
@@ -630,6 +631,7 @@ class DataType:
     F32_NHWC = ("float32", "NHWC")
     F32_HWCN = ("float32", "HWCN")
     F32_NDHWC = ("float32", "NDHWC")
+    F32_FracZNLSTM = ("float32", "FRACTAL_ZN_LSTM")
 
     F64_None = ("float64", "")
     F64_Default = ("float64", "DefaultFormat")