forked from mindspore-Ecosystem/mindspore
!48719 ge add dtype attr for tensorshape, dynamicshape and shape op
Merge pull request !48719 from TuDouNi/ge
This commit is contained in:
commit
b60a7804c0
|
@ -0,0 +1,63 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/ascend/optimizer/ge/tensorshape_for_ge.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "transform/graph_ir/transform_util.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr char kDtypeAttrName[] = "dtype";
|
||||
} // namespace
|
||||
|
||||
const BaseRef TensorShapeForGE::DefinePattern() const {
|
||||
VarPtr V = std::make_shared<Var>();
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
return VectorRef({V, Xs});
|
||||
}
|
||||
|
||||
// Set the attr dtype and convert it to ge_dtype
|
||||
const AnfNodePtr TensorShapeForGE::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
||||
static const PrimitiveSet need_dtype_attr_nodes = {prim::kPrimShape, prim::kPrimTensorShape, prim::kPrimDynamicShape};
|
||||
PrimitivePtr prim = common::AnfAlgo::GetCNodePrimitive(cnode);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
if (need_dtype_attr_nodes.find(prim) == need_dtype_attr_nodes.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (common::AnfAlgo::HasNodeAttr(kDtypeAttrName, cnode)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// get output dtype (ms_dtype)
|
||||
TypeId output_dtype = common::AnfAlgo::GetOutputInferDataType(cnode, 0);
|
||||
// convert to ge_dtype
|
||||
int64_t ge_dtype = static_cast<int64_t>(transform::TransformUtil::ConvertDataType(output_dtype));
|
||||
// update/set attr
|
||||
common::AnfAlgo::SetNodeAttr(kDtypeAttrName, MakeValue(ge_dtype), cnode);
|
||||
|
||||
return node;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_GE_OPTIMIZER_IRPASS_TENSORSHAPE_FOR_GE_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_GE_OPTIMIZER_IRPASS_TENSORSHAPE_FOR_GE_H_
|
||||
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class TensorShapeForGE : public PatternProcessPass {
|
||||
public:
|
||||
explicit TensorShapeForGE(bool multigraph = true) : PatternProcessPass("tensorshape_for_ge", multigraph) {}
|
||||
~TensorShapeForGE() override = default;
|
||||
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_GE_OPTIMIZER_IRPASS_TENSORSHAPE_FOR_GE_H_
|
|
@ -35,6 +35,7 @@
|
|||
#include "plugin/device/ascend/optimizer/ge/sparse_softmax_cross_entropy_with_logits_split.h"
|
||||
#include "plugin/device/ascend/optimizer/enhancer/add_placeholder_for_dynamic_gru.h"
|
||||
#include "plugin/device/ascend/optimizer/enhancer/add_placeholder_for_dynamic_rnn.h"
|
||||
#include "plugin/device/ascend/optimizer/ge/tensorshape_for_ge.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -73,6 +74,7 @@ void GeOptimization(const FuncGraphPtr &func_graph) {
|
|||
pm->AddPass(std::make_shared<opt::ReduceAxisUpdate>());
|
||||
pm->AddPass(std::make_shared<opt::InsertPlaceholderForDynamicGRUV2>());
|
||||
pm->AddPass(std::make_shared<opt::InsertPlaceholderForDynamicRNN>());
|
||||
pm->AddPass(std::make_shared<opt::TensorShapeForGE>());
|
||||
optimizer->AddPassManager(pm);
|
||||
|
||||
(void)optimizer->Optimize(func_graph);
|
||||
|
|
|
@ -83,6 +83,7 @@ const std::map<std::string, std::vector<std::pair<size_t, TypeId>>> kTransInputD
|
|||
{kResizeNearestNeighborGradOpName, {{2, kNumberTypeInt32}}},
|
||||
{kTileOpName, {{1, kNumberTypeInt32}}},
|
||||
{kConv2DBackpropFilterOpName, {{3, kNumberTypeInt32}}},
|
||||
{kConv2DBackpropInputOpName, {{3, kNumberTypeInt32}}},
|
||||
{kOneHotOpName, {{2, kNumberTypeInt32}}},
|
||||
{kLinSpaceOpName, {{3, kNumberTypeInt32}}}};
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ REG_ADPT_DESC(Data, kNameParam, ADPT_DESC(Data))
|
|||
|
||||
// Shape
|
||||
INPUT_MAP(Shape) = {{1, INPUT_DESC(x)}};
|
||||
ATTR_MAP(Shape) = EMPTY_ATTR_MAP;
|
||||
ATTR_MAP(Shape) = {{"dtype", ATTR_DESC(dtype, AnyTraits<int64_t>())}};
|
||||
OUTPUT_MAP(Shape) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(Shape, kNameShape, ADPT_DESC(Shape))
|
||||
|
||||
|
|
Loading…
Reference in New Issue