!48719 ge add dtype attr for tensorshape, dynamicshape and shape op

Merge pull request !48719 from TuDouNi/ge
i-robot 2023-02-20 12:03:28 +00:00 committed by Gitee
commit b60a7804c0
5 changed files with 101 additions and 1 deletion
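Overview: the merge adds a new GE pattern pass, TensorShapeForGE, which tags Shape, TensorShape, and DynamicShape nodes with an explicit "dtype" attribute holding the GE equivalent of the node's inferred output dtype. The pass is registered in GeOptimization, and the GE adapter for Shape is extended to forward the attribute; kTransInputDtypeMap also gains an int32 cast entry (apparently Conv2DBackpropInput's input 3, mirroring the existing Conv2DBackpropFilter entry).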

View File

@@ -0,0 +1,63 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/optimizer/ge/tensorshape_for_ge.h"
#include <vector>
#include <memory>
#include "include/common/utils/anfalgo.h"
#include "transform/graph_ir/transform_util.h"
namespace mindspore {
namespace opt {
namespace {
constexpr char kDtypeAttrName[] = "dtype";
} // namespace
const BaseRef TensorShapeForGE::DefinePattern() const {
  // A Var head plus a SeqVar for the inputs matches any CNode; the op
  // filtering is done in Process.
  VarPtr V = std::make_shared<Var>();
  VarPtr Xs = std::make_shared<SeqVar>();
  return VectorRef({V, Xs});
}

// Set the dtype attr on the node, with the value converted from the MindSpore dtype to its GE equivalent
const AnfNodePtr TensorShapeForGE::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  static const PrimitiveSet need_dtype_attr_nodes = {prim::kPrimShape, prim::kPrimTensorShape, prim::kPrimDynamicShape};
  PrimitivePtr prim = common::AnfAlgo::GetCNodePrimitive(cnode);
  MS_EXCEPTION_IF_NULL(prim);
  if (need_dtype_attr_nodes.find(prim) == need_dtype_attr_nodes.end()) {
    return nullptr;
  }
  if (common::AnfAlgo::HasNodeAttr(kDtypeAttrName, cnode)) {
    return nullptr;
  }
  // get output dtype (ms_dtype)
  TypeId output_dtype = common::AnfAlgo::GetOutputInferDataType(cnode, 0);
  // convert to ge_dtype
  int64_t ge_dtype = static_cast<int64_t>(transform::TransformUtil::ConvertDataType(output_dtype));
  // update/set attr
  common::AnfAlgo::SetNodeAttr(kDtypeAttrName, MakeValue(ge_dtype), cnode);
  return node;
}
} // namespace opt
} // namespace mindspore
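
For intuition, the pass leaves the node's inputs and outputs alone and only records the GE equivalent of the inferred output dtype as an int64 attribute. Below is a minimal self-contained sketch of the same logic; MsTypeId, GeDataType, and Node are hypothetical stand-ins for the real MindSpore/GE types, and the enum values are chosen for the example only:

#include <cstdint>
#include <iostream>
#include <map>
#include <set>
#include <string>

// Illustrative stand-ins; the real code uses mindspore::TypeId and ge::DataType.
enum class MsTypeId { kInt32, kInt64 };
enum class GeDataType : int64_t { DT_INT32 = 3, DT_INT64 = 9 };  // values illustrative only

struct Node {
  std::string op;                        // e.g. "Shape", "TensorShape", "DynamicShape"
  MsTypeId output_dtype;                 // inferred output dtype of the node
  std::map<std::string, int64_t> attrs;  // integer-valued attributes, as set by the pass
};

// Stand-in for transform::TransformUtil::ConvertDataType.
GeDataType ConvertDataType(MsTypeId ms) {
  return ms == MsTypeId::kInt32 ? GeDataType::DT_INT32 : GeDataType::DT_INT64;
}

// Mirrors TensorShapeForGE::Process: skip unrelated ops and nodes that already
// carry the attribute; otherwise store the converted dtype as an int64 attr.
bool SetDtypeAttr(Node *node) {
  static const std::set<std::string> kTargets = {"Shape", "TensorShape", "DynamicShape"};
  if (kTargets.count(node->op) == 0 || node->attrs.count("dtype") > 0) {
    return false;  // nothing to do, like Process returning nullptr
  }
  node->attrs["dtype"] = static_cast<int64_t>(ConvertDataType(node->output_dtype));
  return true;
}

int main() {
  Node shape_node{"Shape", MsTypeId::kInt64, {}};
  SetDtypeAttr(&shape_node);
  std::cout << "dtype attr: " << shape_node.attrs["dtype"] << "\n";  // prints 9
  return 0;
}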

View File

@@ -0,0 +1,34 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_GE_OPTIMIZER_IRPASS_TENSORSHAPE_FOR_GE_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_GE_OPTIMIZER_IRPASS_TENSORSHAPE_FOR_GE_H_
#include "backend/common/optimizer/optimizer.h"
namespace mindspore {
namespace opt {
class TensorShapeForGE : public PatternProcessPass {
 public:
  explicit TensorShapeForGE(bool multigraph = true) : PatternProcessPass("tensorshape_for_ge", multigraph) {}
  ~TensorShapeForGE() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_GE_OPTIMIZER_IRPASS_TENSORSHAPE_FOR_GE_H_

View File

@@ -35,6 +35,7 @@
#include "plugin/device/ascend/optimizer/ge/sparse_softmax_cross_entropy_with_logits_split.h"
#include "plugin/device/ascend/optimizer/enhancer/add_placeholder_for_dynamic_gru.h"
#include "plugin/device/ascend/optimizer/enhancer/add_placeholder_for_dynamic_rnn.h"
#include "plugin/device/ascend/optimizer/ge/tensorshape_for_ge.h"
namespace mindspore {
namespace opt {
@@ -73,6 +74,7 @@ void GeOptimization(const FuncGraphPtr &func_graph) {
  pm->AddPass(std::make_shared<opt::ReduceAxisUpdate>());
  pm->AddPass(std::make_shared<opt::InsertPlaceholderForDynamicGRUV2>());
  pm->AddPass(std::make_shared<opt::InsertPlaceholderForDynamicRNN>());
  pm->AddPass(std::make_shared<opt::TensorShapeForGE>());
  optimizer->AddPassManager(pm);
  (void)optimizer->Optimize(func_graph);
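
With this registration the new pass runs in the same GE pass manager as the passes above, after the dynamic GRU/RNN placeholder insertions, so shape-producing nodes already carry the dtype attribute by the time the graph is converted for GE.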

View File

@@ -83,6 +83,7 @@ const std::map<std::string, std::vector<std::pair<size_t, TypeId>>> kTransInputD
    {kResizeNearestNeighborGradOpName, {{2, kNumberTypeInt32}}},
    {kTileOpName, {{1, kNumberTypeInt32}}},
    {kConv2DBackpropFilterOpName, {{3, kNumberTypeInt32}}},
    {kConv2DBackpropInputOpName, {{3, kNumberTypeInt32}}},
    {kOneHotOpName, {{2, kNumberTypeInt32}}},
    {kLinSpaceOpName, {{3, kNumberTypeInt32}}}};
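
kTransInputDtypeMap maps an op name to the input indices whose values must be cast to a fixed dtype (int32 for all current entries) before the graph is converted for GE; the addition here appears to give Conv2DBackpropInput the same input-3 cast that Conv2DBackpropFilter already has. A sketch of how a table with this shape is typically consulted (NeedCast is a hypothetical helper, not the real MindSpore lookup):

#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in; the real code uses mindspore::TypeId.
enum TypeId { kNumberTypeInt32, kNumberTypeInt64 };

using TransMap = std::map<std::string, std::vector<std::pair<size_t, TypeId>>>;

// Reports whether input `input_idx` of op `op` needs a cast, and to which dtype.
bool NeedCast(const TransMap &m, const std::string &op, size_t input_idx, TypeId *target) {
  auto it = m.find(op);
  if (it == m.end()) {
    return false;
  }
  for (const auto &[idx, dtype] : it->second) {
    if (idx == input_idx) {
      *target = dtype;
      return true;
    }
  }
  return false;
}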

View File

@@ -48,7 +48,7 @@ REG_ADPT_DESC(Data, kNameParam, ADPT_DESC(Data))
// Shape
INPUT_MAP(Shape) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(Shape) = EMPTY_ATTR_MAP;
+ATTR_MAP(Shape) = {{"dtype", ATTR_DESC(dtype, AnyTraits<int64_t>())}};
OUTPUT_MAP(Shape) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Shape, kNameShape, ADPT_DESC(Shape))
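
This is the consumer side of the new pass: the Shape adapter previously exported no attributes (the removed EMPTY_ATTR_MAP line) and now forwards the node's dtype attr, set by TensorShapeForGE, to the GE Shape operator's dtype attribute as an int64.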