!3129 Decouple ir from frontend

Merge pull request !3129 from hewei/decouple_ir_frontend
mindspore-ci-bot 2020-07-21 11:35:33 +08:00 committed by Gitee
commit a2bf5a322e
27 changed files with 465 additions and 681 deletions

View File

@@ -27,6 +27,7 @@
#include "runtime/device/kernel_info.h"
#include "utils/graph_utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "frontend/parallel/ops_info/operator_info.h"
namespace mindspore {
const std::string ToShortString(const TypeId &typeId) {
@@ -266,7 +267,7 @@ void DumpParallelInfo(const CNodePtr &node, const std::shared_ptr<SubGraphIRInfo
return;
}
auto operator_info = node->operator_info();
auto operator_info = node->GetUserData<parallel::OperatorInfo>();
if (operator_info == nullptr) {
return;
}

View File

@@ -437,7 +437,7 @@ static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) {
if (graph_obj == nullptr || node == nullptr) {
return;
}
auto distributed_operation_info = node->operator_info();
auto distributed_operation_info = node->GetUserData<parallel::OperatorInfo>();
if (distributed_operation_info != nullptr) {
auto strategyPtr = distributed_operation_info->strategy();
if (strategyPtr != nullptr) {

View File

@@ -1,293 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/operator/ops.h"
#include <memory>
#include <string>
namespace mindspore {
// namespace to support primitive operators
namespace prim {
// Arithmetic
const PrimitivePtr kPrimScalarAdd = std::make_shared<Primitive>("scalar_add");
const PrimitivePtr kPrimScalarSub = std::make_shared<Primitive>("scalar_sub");
const PrimitivePtr kPrimScalarMul = std::make_shared<Primitive>("scalar_mul");
const PrimitivePtr kPrimScalarDiv = std::make_shared<Primitive>("scalar_div");
const PrimitivePtr kPrimScalarFloordiv = std::make_shared<Primitive>("scalar_floordiv");
const PrimitivePtr kPrimScalarMod = std::make_shared<Primitive>("scalar_mod");
const PrimitivePtr kPrimScalarPow = std::make_shared<Primitive>("scalar_pow");
const PrimitivePtr kPrimScalarTrunc = std::make_shared<Primitive>("scalar_trunc");
const PrimitivePtr kPrimScalarFloor = std::make_shared<Primitive>("scalar_floor");
const PrimitivePtr kPrimScalarUadd = std::make_shared<Primitive>("scalar_uadd");
const PrimitivePtr kPrimScalarUsub = std::make_shared<Primitive>("scalar_usub");
const PrimitivePtr kPrimScalarExp = std::make_shared<Primitive>("scalar_exp");
const PrimitivePtr kPrimScalarLog = std::make_shared<Primitive>("scalar_log");
const PrimitivePtr kPrimScalarSin = std::make_shared<Primitive>("scalar_sin");
const PrimitivePtr kPrimScalarCos = std::make_shared<Primitive>("scalar_cos");
const PrimitivePtr kPrimScalarTan = std::make_shared<Primitive>("scalar_tan");
// Comparisons
const PrimitivePtr kPrimScalarEq = std::make_shared<Primitive>("scalar_eq");
const PrimitivePtr kPrimScalarLt = std::make_shared<Primitive>("scalar_lt");
const PrimitivePtr kPrimScalarGt = std::make_shared<Primitive>("scalar_gt");
const PrimitivePtr kPrimScalarNe = std::make_shared<Primitive>("scalar_ne");
const PrimitivePtr kPrimScalarLe = std::make_shared<Primitive>("scalar_le");
const PrimitivePtr kPrimScalarGe = std::make_shared<Primitive>("scalar_ge");
const PrimitivePtr kPrimBoolNot = std::make_shared<Primitive>("bool_not");
const PrimitivePtr kPrimBoolAnd = std::make_shared<Primitive>("bool_and");
const PrimitivePtr kPrimBoolOr = std::make_shared<Primitive>("bool_or");
const PrimitivePtr kPrimBoolEq = std::make_shared<Primitive>("bool_eq");
const PrimitivePtr kPrimGreater = std::make_shared<Primitive>("Greater");
const PrimitivePtr kPrimGreaterEqual = std::make_shared<Primitive>("GreaterEqual");
const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less");
const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual");
const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal");
const PrimitivePtr kPrimNotEqual = std::make_shared<Primitive>("NotEqual");
// Type introspection
const PrimitivePtr kPrimTypeOf = std::make_shared<Primitive>("typeof");
const PrimitivePtr kPrimHasType = std::make_shared<Primitive>("hastype");
// Statements
const PrimitivePtr kPrimSwitch = std::make_shared<Primitive>("switch");
const PrimitivePtr kPrimSwitchLayer = std::make_shared<Primitive>("switch_layer");
const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return");
const PrimitivePtr kPrimAssign = std::make_shared<Primitive>("Assign");
const PrimitivePtr kPrimAssignAdd = std::make_shared<Primitive>("AssignAdd");
const PrimitivePtr kPrimAssignSub = std::make_shared<Primitive>("AssignSub");
const PrimitivePtr kPrimSelect = std::make_shared<Primitive>("Select");
const PrimitivePtr kPrimCall = std::make_shared<Primitive>("call");
const PrimitivePtr kPrimDistribute = std::make_shared<Primitive>("distribute");
const PrimitivePtr kPrimDot = std::make_shared<Primitive>("dot");
const PrimitivePtr kPrimIm2Col = std::make_shared<Primitive>("im2col");
const PrimitivePtr kPrimCol2Im = std::make_shared<Primitive>("col2im");
const PrimitivePtr kPrimIm2ColV1 = std::make_shared<Primitive>("im2col_v1");
const PrimitivePtr kPrimCol2ImV1 = std::make_shared<Primitive>("col2im_v1");
const PrimitivePtr kPrimResolve = std::make_shared<Primitive>("resolve");
const PrimitivePtr kPrimEmbed = std::make_shared<Primitive>("embed");
const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEmbed");
const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance");
const PrimitivePtr kPrimLabelGoto = std::make_shared<Primitive>("LabelGoto");
const PrimitivePtr kPrimLabelSwitch = std::make_shared<Primitive>("LabelSwitch");
const PrimitivePtr kPrimLabelSet = std::make_shared<Primitive>("LabelSet");
// Structure
const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal");
const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat");
const PrimitivePtr kPrimMakeTuple = std::make_shared<Primitive>("make_tuple");
const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list");
const PrimitivePtr kPrimMakeDict = std::make_shared<Primitive>("make_dict");
const PrimitivePtr kPrimMakeKeywordArg = std::make_shared<Primitive>("make_keyword_arg");
const PrimitivePtr kPrimExtractKeywordArg = std::make_shared<Primitive>("extract_keyword_arg");
const PrimitivePtr kPrimMakeSlice = std::make_shared<Primitive>("make_slice");
const PrimitivePtr kPrimMakeRecord = std::make_shared<Primitive>("make_record");
const PrimitivePtr kPrimTupleGetItem = std::make_shared<Primitive>("tuple_getitem");
const PrimitivePtr kPrimListGetItem = std::make_shared<Primitive>("list_getitem");
const PrimitivePtr kPrimArrayGetItem = std::make_shared<Primitive>("array_getitem");
const PrimitivePtr kPrimTupleSetItem = std::make_shared<Primitive>("tuple_setitem");
const PrimitivePtr kPrimListSetItem = std::make_shared<Primitive>("list_setitem");
const PrimitivePtr kPrimArraySetItem = std::make_shared<Primitive>("array_setitem");
const PrimitivePtr kPrimDictGetItem = std::make_shared<Primitive>("dict_getitem");
const PrimitivePtr kPrimDictSetItem = std::make_shared<Primitive>("dict_setitem");
const PrimitivePtr kPrimListAppend = std::make_shared<Primitive>("list_append");
const PrimitivePtr kPrimGetAttr = std::make_shared<Primitive>("getattr");
const PrimitivePtr kPrimTupleLen = std::make_shared<Primitive>("tuple_len");
const PrimitivePtr kPrimDictLen = std::make_shared<Primitive>("dict_len");
const PrimitivePtr kPrimListLen = std::make_shared<Primitive>("list_len");
const PrimitivePtr kPrimArrayLen = std::make_shared<Primitive>("array_len");
const PrimitivePtr kPrimListMap = std::make_shared<Primitive>("list_map");
const PrimitivePtr kPrimListReduce = std::make_shared<Primitive>("list_reduce");
const PrimitivePtr kPrimTupleReversed = std::make_shared<Primitive>("tuple_reversed");
const PrimitivePtr kPrimTileShape = std::make_shared<Primitive>("tile_shape");
const PrimitivePtr kPrimReducedShape = std::make_shared<Primitive>("reduced_shape");
const PrimitivePtr kPrimTupleDiv = std::make_shared<Primitive>("tuple_div");
const PrimitivePtr kPrimTupleToArray = std::make_shared<Primitive>("tuple_to_array");
const PrimitivePtr kPrimShapeMul = std::make_shared<Primitive>("shape_mul");
const PrimitivePtr kPrimGenerateShapeIndex = std::make_shared<Primitive>("generate_shape_index");
const PrimitivePtr kPrimGenerateInverseIndex = std::make_shared<Primitive>("generate_inverse_index");
const PrimitivePtr kPrimTupleEqual = std::make_shared<Primitive>("tuple_equal");
const PrimitivePtr kPrimListEqual = std::make_shared<Primitive>("list_equal");
const PrimitivePtr kPrimMakeRange = std::make_shared<Primitive>("make_range");
const PrimitivePtr kPrimStopGradient = std::make_shared<Primitive>("stop_gradient");
// Arrays
const PrimitivePtr kPrimScalarToArray = std::make_shared<Primitive>("scalar_to_array");
const PrimitivePtr kPrimArrayToScalar = std::make_shared<Primitive>("array_to_scalar");
const PrimitivePtr kPrimBroadcastShape = std::make_shared<Primitive>("broadcast_shape");
const PrimitivePtr kPrimArrayMap = std::make_shared<Primitive>("array_map");
const PrimitivePtr kPrimArrayReduce = std::make_shared<Primitive>("array_reduce");
const PrimitivePtr kPrimShape = std::make_shared<Primitive>("Shape");
const PrimitivePtr kPrimCast = std::make_shared<Primitive>("Cast");
const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat");
const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze");
const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose");
const PrimitivePtr kPrimGatherV2 = std::make_shared<Primitive>("GatherV2");
const PrimitivePtr kPrimEmbeddingLookup = std::make_shared<Primitive>("EmbeddingLookup");
const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared<Primitive>("EmbeddingLookupCommGrad");
const PrimitivePtr kPrimSize = std::make_shared<Primitive>("Size");
const PrimitivePtr kPrimArgMax = std::make_shared<Primitive>("Argmax");
const PrimitivePtr kPrimPack = std::make_shared<Primitive>("Pack");
const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>("UnsortedSegmentSum");
const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin");
const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("ConcatOffset");
const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape");
const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile");
const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN");
const PrimitivePtr KPrimTransData = std::make_shared<Primitive>("TransData");
const PrimitivePtr kPrimNMSWithMask = std::make_shared<Primitive>("NMSWithMask");
const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad");
const PrimitivePtr kPrimArgMaxWithValue = std::make_shared<Primitive>("ArgMaxWithValue");
// Maths
const PrimitivePtr kPrimTensorAdd = std::make_shared<Primitive>("TensorAdd");
const PrimitivePtr kPrimMatMul = std::make_shared<Primitive>("MatMul");
const PrimitivePtr kPrimBatchMatMul = std::make_shared<Primitive>("BatchMatMul");
const PrimitivePtr kPrimMaximumGrad = std::make_shared<Primitive>("MaximumGrad");
const PrimitivePtr kPrimMinimumGrad = std::make_shared<Primitive>("MinimumGrad");
const PrimitivePtr kPrimReduceMean = std::make_shared<Primitive>("ReduceMean");
const PrimitivePtr kPrimReduceSum = std::make_shared<Primitive>("ReduceSum");
const PrimitivePtr kPrimReduceAll = std::make_shared<Primitive>("ReduceAll");
const PrimitivePtr kPrimReduceMax = std::make_shared<Primitive>("ReduceMax");
const PrimitivePtr kPrimReduceMin = std::make_shared<Primitive>("ReduceMin");
const PrimitivePtr kPrimNeg = std::make_shared<Primitive>("Neg");
const PrimitivePtr kPrimSub = std::make_shared<Primitive>("Sub");
const PrimitivePtr kPrimMul = std::make_shared<Primitive>("Mul");
const PrimitivePtr kPrimMinimum = std::make_shared<Primitive>("Minimum");
const PrimitivePtr kPrimMaximum = std::make_shared<Primitive>("Maximum");
const PrimitivePtr kPrimSquare = std::make_shared<Primitive>("Square");
const PrimitivePtr kPrimCumSum = std::make_shared<Primitive>("CumSum");
const PrimitivePtr kPrimCumProd = std::make_shared<Primitive>("CumProd");
const PrimitivePtr kPrimSubscalar = std::make_shared<Primitive>("Subscalar");
const PrimitivePtr kPrimInplaceAdd = std::make_shared<Primitive>("InplaceAdd");
const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub");
const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow");
const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>("RealDiv");
const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt");
const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>("Reciprocal");
const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandDims");
// NN
const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
const PrimitivePtr kPrimSoftmax = std::make_shared<Primitive>("Softmax");
const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax");
const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad");
const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh");
const PrimitivePtr kPrimTanhGrad = std::make_shared<Primitive>("TanhGrad");
const PrimitivePtr kPrimPooling = std::make_shared<Primitive>("Pooling");
const PrimitivePtr kPrimPoolingGrad = std::make_shared<Primitive>("PoolingGrad");
const PrimitivePtr kPrimMaxPool = std::make_shared<Primitive>("MaxPool");
const PrimitivePtr kPrimMaxPoolGrad = std::make_shared<Primitive>("MaxPoolGrad");
const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared<Primitive>("ApplyCenteredRMSProp");
const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad");
const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm");
const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D");
const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>("FusedBatchNormGrad");
const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm");
const PrimitivePtr kPrimBatchNormGrad = std::make_shared<Primitive>("BatchNormGrad");
const PrimitivePtr kPrimReluGrad = std::make_shared<Primitive>("ReluGrad");
const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared<Primitive>("Conv2DBackpropInput");
const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive>("Conv2DBackpropFilter");
const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared<Primitive>("DepthwiseConv2dNative");
const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter =
std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropFilter");
const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput =
std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropInput");
const PrimitivePtr kPrimBiasAddGrad = std::make_shared<Primitive>("BiasAddGrad");
const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits = std::make_shared<Primitive>("SoftmaxCrossEntropyWithLogits");
const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits =
std::make_shared<Primitive>("SparseSoftmaxCrossEntropyWithLogits");
const PrimitivePtr kPrimMomentum = std::make_shared<Primitive>("Momentum");
const PrimitivePtr kPrimApplyMomentum = std::make_shared<Primitive>("ApplyMomentum");
const PrimitivePtr kPrimLayerNorm = std::make_shared<Primitive>("LayerNorm");
const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>("LayerNormGrad");
const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared<Primitive>("LayerNormXBackprop");
const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Primitive>("LayerNormBetaGammaBackprop");
const PrimitivePtr kPrimDropoutGenMask = std::make_shared<Primitive>("DropoutGenMask");
const PrimitivePtr kPrimDropoutDoMask = std::make_shared<Primitive>("DropoutDoMask");
const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot");
const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu");
const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad");
const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU");
const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike");
const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");
const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer");
const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel");
const PrimitivePtr kPrimApplyRMSProp = std::make_shared<Primitive>("ApplyRMSProp");
// Other miscellaneous
const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity");
const PrimitivePtr kPrimPartial = std::make_shared<Primitive>("Partial");
const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J");
const PrimitivePtr kPrimEnvSetItem = std::make_shared<Primitive>("env_setitem");
const PrimitivePtr kPrimEnvGetItem = std::make_shared<Primitive>("env_getitem");
const PrimitivePtr kPrimEnvAdd = std::make_shared<Primitive>("env_add");
const PrimitivePtr kPrimMakeRefKey = std::make_shared<Primitive>("MakeRefKey");
const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key");
const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin");
const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
const PrimitivePtr kPrimHookBackward = std::make_shared<Primitive>("HookBackward");
const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType");
const PrimitivePtr kPrimSameTypeShape = std::make_shared<Primitive>("SameTypeShape");
const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop");
const PrimitivePtr kPrimPrint = std::make_shared<Primitive>("Print");
const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");
const PrimitivePtr kPrimDepend = std::make_shared<Primitive>("Depend");
const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem");
const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("BroadcastGradientArgs");
const PrimitivePtr kPrimControlDepend = std::make_shared<Primitive>("ControlDepend");
const PrimitivePtr kPrimIs_ = std::make_shared<Primitive>("is_");
const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not");
const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict");
const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast");
const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant");
const PrimitivePtr kPrimEquivFormat = std::make_shared<Primitive>("EquivFormat");
// Comm ops
const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
// Debug ops
const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary");
const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary");
const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary");
const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary");
const PrimitivePtr kPrimDebug = std::make_shared<Primitive>("Debug");
// IndexedSlices
const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared<Primitive>("MakeIndexedSlices");
const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared<Primitive>("IndexedSlicesGetValues");
const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared<Primitive>("IndexedSlicesGetIndices");
const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared<Primitive>("IndexedSlicesGetDenseShape");
// SparseTensor
const PrimitivePtr kPrimMakeSparseTensor = std::make_shared<Primitive>("MakeSparseTensor");
const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared<Primitive>("SparseTensorGetValues");
const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primitive>("SparseTensorGetIndices");
const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape");
} // namespace prim
} // namespace mindspore
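
The deleted frontend source file above, which defined every primitive, becomes unnecessary because, as the next header diff shows, each extern const PrimitivePtr declaration is replaced by an inline const PrimitivePtr definition, i.e. a C++17 inline variable that lives entirely in the header. A minimal, self-contained sketch of the before/after pattern; Primitive here is a stripped-down stand-in and kPrimExample a hypothetical name, not MindSpore code:

// Illustrative sketch only (hypothetical names, C++17); not part of this diff.
#include <memory>
#include <string>

// Stripped-down stand-in for the Primitive class.
struct Primitive {
  explicit Primitive(std::string name) : name_(std::move(name)) {}
  std::string name_;
};
using PrimitivePtr = std::shared_ptr<Primitive>;

// Old pattern: the header declares, exactly one .cc file defines.
//   ops.h:  extern const PrimitivePtr kPrimExample;
//   ops.cc: const PrimitivePtr kPrimExample = std::make_shared<Primitive>("Example");

// New pattern: a C++17 inline variable is defined in the header itself; every
// translation unit that includes it shares one object, so the separate .cc
// definition (and the frontend dependency it carried) can be deleted.
inline const PrimitivePtr kPrimExample = std::make_shared<Primitive>("Example");

int main() { return kPrimExample->name_ == "Example" ? 0 : 1; }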

View File

@@ -22,6 +22,7 @@
#include <memory>
#include "ir/anf.h"
#include "ir/primitive.h"
#include "base/core_ops.h"
namespace mindspore {
// namespace to support primitive operators
@@ -31,273 +32,158 @@ ValuePtr GetPythonOps(const std::string &op_name,
bool use_signature = false);
// Arithmetic
extern const PrimitivePtr kPrimScalarAdd;
extern const PrimitivePtr kPrimScalarSub;
extern const PrimitivePtr kPrimScalarMul;
extern const PrimitivePtr kPrimScalarDiv;
extern const PrimitivePtr kPrimScalarFloordiv;
extern const PrimitivePtr kPrimScalarMod;
extern const PrimitivePtr kPrimScalarPow;
extern const PrimitivePtr kPrimScalarTrunc;
extern const PrimitivePtr kPrimScalarFloor;
extern const PrimitivePtr kPrimScalarUadd;
extern const PrimitivePtr kPrimScalarUsub;
extern const PrimitivePtr kPrimScalarExp;
extern const PrimitivePtr kPrimScalarLog;
extern const PrimitivePtr kPrimScalarSin;
extern const PrimitivePtr kPrimScalarCos;
extern const PrimitivePtr kPrimScalarTan;
inline const PrimitivePtr kPrimScalarAdd = std::make_shared<Primitive>("scalar_add");
inline const PrimitivePtr kPrimScalarSub = std::make_shared<Primitive>("scalar_sub");
inline const PrimitivePtr kPrimScalarMul = std::make_shared<Primitive>("scalar_mul");
inline const PrimitivePtr kPrimScalarDiv = std::make_shared<Primitive>("scalar_div");
inline const PrimitivePtr kPrimScalarFloordiv = std::make_shared<Primitive>("scalar_floordiv");
inline const PrimitivePtr kPrimScalarMod = std::make_shared<Primitive>("scalar_mod");
inline const PrimitivePtr kPrimScalarPow = std::make_shared<Primitive>("scalar_pow");
inline const PrimitivePtr kPrimScalarTrunc = std::make_shared<Primitive>("scalar_trunc");
inline const PrimitivePtr kPrimScalarFloor = std::make_shared<Primitive>("scalar_floor");
inline const PrimitivePtr kPrimScalarUadd = std::make_shared<Primitive>("scalar_uadd");
inline const PrimitivePtr kPrimScalarUsub = std::make_shared<Primitive>("scalar_usub");
inline const PrimitivePtr kPrimScalarExp = std::make_shared<Primitive>("scalar_exp");
inline const PrimitivePtr kPrimScalarLog = std::make_shared<Primitive>("scalar_log");
inline const PrimitivePtr kPrimScalarSin = std::make_shared<Primitive>("scalar_sin");
inline const PrimitivePtr kPrimScalarCos = std::make_shared<Primitive>("scalar_cos");
inline const PrimitivePtr kPrimScalarTan = std::make_shared<Primitive>("scalar_tan");
// Comparisons
extern const PrimitivePtr kPrimScalarEq;
extern const PrimitivePtr kPrimScalarLt;
extern const PrimitivePtr kPrimScalarGt;
extern const PrimitivePtr kPrimScalarNe;
extern const PrimitivePtr kPrimScalarLe;
extern const PrimitivePtr kPrimScalarGe;
extern const PrimitivePtr kPrimBoolNot;
extern const PrimitivePtr kPrimBoolAnd;
extern const PrimitivePtr kPrimBoolOr;
extern const PrimitivePtr kPrimBoolEq;
extern const PrimitivePtr kPrimGreater;
extern const PrimitivePtr kPrimGreaterEqual;
extern const PrimitivePtr kPrimLess;
extern const PrimitivePtr kPrimLessEqual;
extern const PrimitivePtr kPrimEqual;
extern const PrimitivePtr kPrimNotEqual;
inline const PrimitivePtr kPrimScalarEq = std::make_shared<Primitive>("scalar_eq");
inline const PrimitivePtr kPrimScalarLt = std::make_shared<Primitive>("scalar_lt");
inline const PrimitivePtr kPrimScalarGt = std::make_shared<Primitive>("scalar_gt");
inline const PrimitivePtr kPrimScalarNe = std::make_shared<Primitive>("scalar_ne");
inline const PrimitivePtr kPrimScalarLe = std::make_shared<Primitive>("scalar_le");
inline const PrimitivePtr kPrimScalarGe = std::make_shared<Primitive>("scalar_ge");
inline const PrimitivePtr kPrimBoolNot = std::make_shared<Primitive>("bool_not");
inline const PrimitivePtr kPrimBoolAnd = std::make_shared<Primitive>("bool_and");
inline const PrimitivePtr kPrimBoolOr = std::make_shared<Primitive>("bool_or");
inline const PrimitivePtr kPrimBoolEq = std::make_shared<Primitive>("bool_eq");
inline const PrimitivePtr kPrimGreater = std::make_shared<Primitive>("Greater");
inline const PrimitivePtr kPrimGreaterEqual = std::make_shared<Primitive>("GreaterEqual");
inline const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less");
inline const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual");
inline const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal");
inline const PrimitivePtr kPrimNotEqual = std::make_shared<Primitive>("NotEqual");
// Type introspection
extern const PrimitivePtr kPrimTypeOf;
extern const PrimitivePtr kPrimHasType;
inline const PrimitivePtr kPrimTypeOf = std::make_shared<Primitive>("typeof");
inline const PrimitivePtr kPrimHasType = std::make_shared<Primitive>("hastype");
// Statements
extern const PrimitivePtr kPrimSwitch;
extern const PrimitivePtr kPrimSwitchLayer;
extern const PrimitivePtr kPrimReturn;
extern const PrimitivePtr kPrimAssign;
extern const PrimitivePtr kPrimAssignAdd;
extern const PrimitivePtr kPrimAssignSub;
extern const PrimitivePtr kPrimSelect;
extern const PrimitivePtr kPrimCall;
inline const PrimitivePtr kPrimDistribute = std::make_shared<Primitive>("distribute");
inline const PrimitivePtr kPrimDot = std::make_shared<Primitive>("dot");
inline const PrimitivePtr kPrimIm2Col = std::make_shared<Primitive>("im2col");
inline const PrimitivePtr kPrimCol2Im = std::make_shared<Primitive>("col2im");
inline const PrimitivePtr kPrimIm2ColV1 = std::make_shared<Primitive>("im2col_v1");
inline const PrimitivePtr kPrimCol2ImV1 = std::make_shared<Primitive>("col2im_v1");
extern const PrimitivePtr kPrimDistribute;
extern const PrimitivePtr kPrimDot;
extern const PrimitivePtr kPrimIm2Col;
extern const PrimitivePtr kPrimCol2Im;
extern const PrimitivePtr kPrimIm2ColV1;
extern const PrimitivePtr kPrimCol2ImV1;
inline const PrimitivePtr kPrimResolve = std::make_shared<Primitive>("resolve");
inline const PrimitivePtr kPrimEmbed = std::make_shared<Primitive>("embed");
inline const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEmbed");
inline const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance");
extern const PrimitivePtr kPrimResolve;
extern const PrimitivePtr kPrimEmbed;
extern const PrimitivePtr kPrimRefToEmbed;
extern const PrimitivePtr kPrimCreateInstance;
extern const PrimitivePtr kPrimLabelGoto;
extern const PrimitivePtr kPrimLabelSwitch;
extern const PrimitivePtr kPrimLabelSet;
// Structure
extern const PrimitivePtr kPrimStringEqual;
extern const PrimitivePtr kPrimStringConcat;
extern const PrimitivePtr kPrimMakeTuple;
extern const PrimitivePtr kPrimMakeList;
extern const PrimitivePtr kPrimMakeDict;
extern const PrimitivePtr kPrimMakeKeywordArg;
extern const PrimitivePtr kPrimExtractKeywordArg;
extern const PrimitivePtr kPrimMakeSlice;
extern const PrimitivePtr kPrimMakeRecord;
extern const PrimitivePtr kPrimTupleGetItem;
extern const PrimitivePtr kPrimListGetItem;
extern const PrimitivePtr kPrimArrayGetItem;
extern const PrimitivePtr kPrimTupleSetItem;
extern const PrimitivePtr kPrimListSetItem;
extern const PrimitivePtr kPrimArraySetItem;
extern const PrimitivePtr kPrimDictGetItem;
extern const PrimitivePtr kPrimDictSetItem;
extern const PrimitivePtr kPrimListAppend;
extern const PrimitivePtr kPrimGetAttr;
extern const PrimitivePtr kPrimTupleLen;
extern const PrimitivePtr kPrimDictLen;
extern const PrimitivePtr kPrimListLen;
extern const PrimitivePtr kPrimArrayLen;
extern const PrimitivePtr kPrimListMap;
extern const PrimitivePtr kPrimListReduce;
extern const PrimitivePtr kPrimTupleReversed;
extern const PrimitivePtr kPrimTileShape;
extern const PrimitivePtr kPrimReducedShape;
extern const PrimitivePtr kPrimTupleDiv;
extern const PrimitivePtr kPrimTupleToArray;
extern const PrimitivePtr kPrimShapeMul;
extern const PrimitivePtr kPrimGenerateShapeIndex;
extern const PrimitivePtr kPrimGenerateInverseIndex;
extern const PrimitivePtr kPrimTupleEqual;
extern const PrimitivePtr kPrimListEqual;
extern const PrimitivePtr kPrimMakeRange;
extern const PrimitivePtr kPrimStopGradient;
inline const PrimitivePtr kPrimLabelGoto = std::make_shared<Primitive>("LabelGoto");
inline const PrimitivePtr kPrimLabelSwitch = std::make_shared<Primitive>("LabelSwitch");
inline const PrimitivePtr kPrimLabelSet = std::make_shared<Primitive>("LabelSet");
// Arrays
extern const PrimitivePtr kPrimScalarToArray;
extern const PrimitivePtr kPrimArrayToScalar;
extern const PrimitivePtr kPrimBroadcastShape;
extern const PrimitivePtr kPrimArrayMap;
extern const PrimitivePtr kPrimArrayReduce;
extern const PrimitivePtr kPrimShape;
extern const PrimitivePtr kPrimCast;
extern const PrimitivePtr kPrimConcat;
extern const PrimitivePtr kPrimSqueeze;
extern const PrimitivePtr kPrimTranspose;
extern const PrimitivePtr kPrimGatherV2;
extern const PrimitivePtr kPrimEmbeddingLookup;
extern const PrimitivePtr kPrimEmbeddingLookupCommGrad;
extern const PrimitivePtr kPrimSize;
extern const PrimitivePtr kPrimArgMax;
extern const PrimitivePtr kPrimPack;
extern const PrimitivePtr kPrimUnpack;
extern const PrimitivePtr kPrimUnsortedSegmentMin;
extern const PrimitivePtr kPrimUnsortedSegmentSum;
extern const PrimitivePtr kPrimConcatOffset;
extern const PrimitivePtr kPrimReshape;
extern const PrimitivePtr kPrimTile;
extern const PrimitivePtr kPrimAddN;
extern const PrimitivePtr KPrimTransData;
extern const PrimitivePtr kPrimNMSWithMask;
extern const PrimitivePtr kPrimPad;
extern const PrimitivePtr kPrimArgMaxWithValue;
extern const PrimitivePtr kPrimRealDiv;
extern const PrimitivePtr kPrimSqrt;
extern const PrimitivePtr kPrimReciprocal;
extern const PrimitivePtr kPrimExpandDims;
// Maths
extern const PrimitivePtr kPrimTensorAdd;
extern const PrimitivePtr kPrimMatMul;
extern const PrimitivePtr kPrimBatchMatMul;
extern const PrimitivePtr kPrimMaximumGrad;
extern const PrimitivePtr kPrimMinimumGrad;
extern const PrimitivePtr kPrimReduceMean;
extern const PrimitivePtr kPrimReduceSum;
extern const PrimitivePtr kPrimReduceAll;
extern const PrimitivePtr kPrimReduceMax;
extern const PrimitivePtr kPrimReduceMin;
extern const PrimitivePtr kPrimNeg;
extern const PrimitivePtr kPrimSub;
extern const PrimitivePtr kPrimMul;
extern const PrimitivePtr kPrimRealDiv;
extern const PrimitivePtr kPrimMinimum;
extern const PrimitivePtr kPrimMaximum;
extern const PrimitivePtr kPrimSquare;
extern const PrimitivePtr kPrimSqrt;
extern const PrimitivePtr kPrimEqual;
extern const PrimitivePtr kPrimLess;
extern const PrimitivePtr kPrimLessEqual;
extern const PrimitivePtr kPrimCumSum;
extern const PrimitivePtr kPrimCumProd;
extern const PrimitivePtr kPrimSubscalar;
extern const PrimitivePtr kPrimInplaceAdd;
extern const PrimitivePtr kPrimInplaceSub;
extern const PrimitivePtr kPrimPow;
inline const PrimitivePtr kPrimScalarToArray = std::make_shared<Primitive>("scalar_to_array");
inline const PrimitivePtr kPrimArrayToScalar = std::make_shared<Primitive>("array_to_scalar");
inline const PrimitivePtr kPrimBroadcastShape = std::make_shared<Primitive>("broadcast_shape");
inline const PrimitivePtr kPrimArrayMap = std::make_shared<Primitive>("array_map");
inline const PrimitivePtr kPrimArrayReduce = std::make_shared<Primitive>("array_reduce");
inline const PrimitivePtr kPrimShape = std::make_shared<Primitive>("Shape");
inline const PrimitivePtr kPrimCast = std::make_shared<Primitive>("Cast");
inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat");
inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze");
inline const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose");
inline const PrimitivePtr kPrimGatherV2 = std::make_shared<Primitive>("GatherV2");
inline const PrimitivePtr kPrimEmbeddingLookup = std::make_shared<Primitive>("EmbeddingLookup");
inline const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared<Primitive>("EmbeddingLookupCommGrad");
inline const PrimitivePtr kPrimSize = std::make_shared<Primitive>("Size");
inline const PrimitivePtr kPrimArgMax = std::make_shared<Primitive>("Argmax");
inline const PrimitivePtr kPrimPack = std::make_shared<Primitive>("Pack");
inline const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>("UnsortedSegmentSum");
inline const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin");
inline const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("ConcatOffset");
inline const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape");
inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile");
inline const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN");
inline const PrimitivePtr KPrimTransData = std::make_shared<Primitive>("TransData");
inline const PrimitivePtr kPrimNMSWithMask = std::make_shared<Primitive>("NMSWithMask");
inline const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad");
inline const PrimitivePtr kPrimArgMaxWithValue = std::make_shared<Primitive>("ArgMaxWithValue");
// NN
extern const PrimitivePtr kPrimFlatten;
extern const PrimitivePtr kPrimSoftmax;
extern const PrimitivePtr kPrimLogSoftmax;
extern const PrimitivePtr kPrimLogSoftmaxGrad;
extern const PrimitivePtr kPrimApplyCenteredRMSProp;
extern const PrimitivePtr kPrimTanh;
extern const PrimitivePtr kPrimTanhGrad;
extern const PrimitivePtr kPrimPooling;
extern const PrimitivePtr kPrimPoolingGrad;
extern const PrimitivePtr kPrimFusedBatchNorm;
extern const PrimitivePtr kPrimBatchNorm;
extern const PrimitivePtr kPrimBatchNormGrad;
extern const PrimitivePtr kPrimConv2D;
extern const PrimitivePtr kPrimMaxPool;
extern const PrimitivePtr kPrimMaxPoolGrad;
extern const PrimitivePtr kPrimAvgPoolGrad;
extern const PrimitivePtr kPrimFusedBatchNormGrad;
extern const PrimitivePtr kPrimReluGrad;
extern const PrimitivePtr kPrimConv2DBackpropInput;
extern const PrimitivePtr kPrimConv2DBackpropFilter;
extern const PrimitivePtr kPrimDepthwiseConv2dNative;
extern const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter;
extern const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput;
extern const PrimitivePtr kPrimBiasAddGrad;
extern const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits;
extern const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits;
extern const PrimitivePtr kPrimMomentum;
extern const PrimitivePtr kPrimApplyMomentum;
extern const PrimitivePtr kPrimLayerNorm;
extern const PrimitivePtr kPrimLayerNormGrad;
extern const PrimitivePtr kPrimLayerNormXBackprop;
extern const PrimitivePtr kPrimLayerNormBetaGammaBackprop;
extern const PrimitivePtr kPrimDropoutGenMask;
extern const PrimitivePtr kPrimDropoutDoMask;
extern const PrimitivePtr kPrimOneHot;
extern const PrimitivePtr kPrimGelu;
extern const PrimitivePtr kPrimGeluGrad;
extern const PrimitivePtr kPrimRelu;
extern const PrimitivePtr kPrimReluV2;
extern const PrimitivePtr kPrimActivation;
extern const PrimitivePtr kPrimZerosLike;
extern const PrimitivePtr kPrimFakeBprop;
extern const PrimitivePtr kPrimBpropCut;
extern const PrimitivePtr kPrimFakeQuantPerLayer;
extern const PrimitivePtr kPrimFakeQuantPerChannel;
extern const PrimitivePtr kPrimApplyRMSProp;
// Other Miscellaneous
extern const PrimitivePtr kPrimIdentity;
extern const PrimitivePtr kPrimPartial;
extern const PrimitivePtr kPrimJ;
extern const PrimitivePtr kPrimEnvSetItem;
extern const PrimitivePtr kPrimEnvGetItem;
extern const PrimitivePtr kPrimEnvAdd;
extern const PrimitivePtr kPrimMakeRefKey;
extern const PrimitivePtr kPrimMakeRef;
extern const PrimitivePtr kPrimGetRefKey;
extern const PrimitivePtr kPrimGetRefValue;
extern const PrimitivePtr kPrimGetRefOrigin;
extern const PrimitivePtr kPrimInsertGradientOf;
extern const PrimitivePtr kPrimHookBackward;
extern const PrimitivePtr kPrimPrintShapeType;
extern const PrimitivePtr kPrimPrint;
extern const PrimitivePtr kPrimSameTypeShape;
extern const PrimitivePtr kPrimCheckBprop;
extern const PrimitivePtr kPrimDepend;
extern const PrimitivePtr kPrimStateSetItem;
extern const PrimitivePtr kPrimScalarSummary;
extern const PrimitivePtr kPrimImageSummary;
extern const PrimitivePtr kPrimTensorSummary;
extern const PrimitivePtr kPrimHistogramSummary;
extern const PrimitivePtr kPrimBroadcastGradientArgs;
extern const PrimitivePtr kPrimControlDepend;
extern const PrimitivePtr kPrimIs_;
extern const PrimitivePtr kPrimIsNot;
extern const PrimitivePtr kPrimInDict;
extern const PrimitivePtr kPrimNotInDict;
extern const PrimitivePtr kPrimMixedPrecisionCast;
extern const PrimitivePtr kPrimIsConsant;
extern const PrimitivePtr kPrimEquivFormat;
extern const PrimitivePtr kPrimDebug;
inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
inline const PrimitivePtr kPrimSoftmax = std::make_shared<Primitive>("Softmax");
inline const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax");
inline const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad");
inline const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh");
inline const PrimitivePtr kPrimTanhGrad = std::make_shared<Primitive>("TanhGrad");
inline const PrimitivePtr kPrimPooling = std::make_shared<Primitive>("Pooling");
inline const PrimitivePtr kPrimPoolingGrad = std::make_shared<Primitive>("PoolingGrad");
inline const PrimitivePtr kPrimMaxPool = std::make_shared<Primitive>("MaxPool");
inline const PrimitivePtr kPrimMaxPoolGrad = std::make_shared<Primitive>("MaxPoolGrad");
inline const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared<Primitive>("ApplyCenteredRMSProp");
inline const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad");
inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm");
inline const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D");
inline const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>("FusedBatchNormGrad");
inline const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm");
inline const PrimitivePtr kPrimBatchNormGrad = std::make_shared<Primitive>("BatchNormGrad");
inline const PrimitivePtr kPrimReluGrad = std::make_shared<Primitive>("ReluGrad");
inline const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared<Primitive>("Conv2DBackpropInput");
inline const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive>("Conv2DBackpropFilter");
inline const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared<Primitive>("DepthwiseConv2dNative");
inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter =
std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropFilter");
inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput =
std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropInput");
inline const PrimitivePtr kPrimBiasAddGrad = std::make_shared<Primitive>("BiasAddGrad");
inline const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits =
std::make_shared<Primitive>("SoftmaxCrossEntropyWithLogits");
inline const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits =
std::make_shared<Primitive>("SparseSoftmaxCrossEntropyWithLogits");
inline const PrimitivePtr kPrimMomentum = std::make_shared<Primitive>("Momentum");
inline const PrimitivePtr kPrimApplyMomentum = std::make_shared<Primitive>("ApplyMomentum");
inline const PrimitivePtr kPrimLayerNorm = std::make_shared<Primitive>("LayerNorm");
inline const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>("LayerNormGrad");
inline const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared<Primitive>("LayerNormXBackprop");
inline const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Primitive>("LayerNormBetaGammaBackprop");
inline const PrimitivePtr kPrimDropoutGenMask = std::make_shared<Primitive>("DropoutGenMask");
inline const PrimitivePtr kPrimDropoutDoMask = std::make_shared<Primitive>("DropoutDoMask");
inline const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot");
inline const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu");
inline const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad");
inline const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU");
inline const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
inline const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike");
inline const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
inline const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");
inline const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer");
inline const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel");
inline const PrimitivePtr kPrimApplyRMSProp = std::make_shared<Primitive>("ApplyRMSProp");
// Comm ops
extern const PrimitivePtr kPrimAllReduce;
extern const PrimitivePtr kPrimMirror;
extern const PrimitivePtr kPrimVirtualDiv;
extern const PrimitivePtr kPrimVirtualDataset;
inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
// IndexedSlices
extern const PrimitivePtr kPrimMakeIndexedSlices;
extern const PrimitivePtr kPrimIndexedSlicesGetValues;
extern const PrimitivePtr kPrimIndexedSlicesGetIndices;
extern const PrimitivePtr kPrimIndexedSlicesGetDenseShape;
inline const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared<Primitive>("MakeIndexedSlices");
inline const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared<Primitive>("IndexedSlicesGetValues");
inline const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared<Primitive>("IndexedSlicesGetIndices");
inline const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared<Primitive>("IndexedSlicesGetDenseShape");
inline const PrimitivePtr kPrimIsIndexedSlices = std::make_shared<Primitive>("IsIndexedSlices");
// SparseTensor
extern const PrimitivePtr kPrimMakeSparseTensor;
extern const PrimitivePtr kPrimSparseTensorGetValues;
extern const PrimitivePtr kPrimSparseTensorGetIndices;
extern const PrimitivePtr kPrimSparseTensorGetDenseShape;
inline const PrimitivePtr kPrimMakeSparseTensor = std::make_shared<Primitive>("MakeSparseTensor");
inline const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared<Primitive>("SparseTensorGetValues");
inline const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primitive>("SparseTensorGetIndices");
inline const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape");
// attribute 'unroll_flag' of primitive 'switch', when 'unroll_flag' is '0', 'switch' will not unroll
const char SWITCH_UNROLL_FLAG[] = "unroll_flag";
@@ -305,22 +191,6 @@ const char SWITCH_UNROLL_FLAG[] = "unroll_flag";
// will be sunk(i.e. not unrolled)
const int MAX_FOR_LOOP_COUNT = 600;
class DoSignaturePrimitive : public Primitive {
public:
explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function)
: Primitive("S-Prim-" + name), function_(function) {}
~DoSignaturePrimitive() override = default;
MS_DECLARE_PARENT(DoSignaturePrimitive, Primitive)
const ValuePtr function() const { return function_; }
private:
ValuePtr function_;
};
using DoSignaturePrimitivePtr = std::shared_ptr<DoSignaturePrimitive>;
class UnpackGraphPrimitive : public Primitive {
public:
explicit UnpackGraphPrimitive(const std::string &name, const bool &with_sens, const bool &need_unpack_args)

View File

@@ -50,7 +50,7 @@ std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr &para, uint32_t
if (node_prim->name() == DEPEND && node_pair.second != 1) {
continue;
}
if (IsParallelCareNode(cnode) && cnode->operator_info() != nullptr) {
if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) {
(void)cnode_set.emplace(cnode);
} else {
auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1);
@@ -98,11 +98,12 @@ CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint32_t recursi
return cnode_dist;
}
auto operator_info = cnode->GetUserData<OperatorInfo>();
MS_LOG(DEBUG) << "cnode " << cnode->ToString() << " IsParallelCareNode: " << IsParallelCareNode(cnode)
<< " operator_info: " << (cnode->operator_info() != nullptr);
<< " operator_info: " << (operator_info != nullptr);
if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) {
auto cost = cnode->operator_info()->GetForwardMemoryCostFromCNode();
if (IsParallelCareNode(cnode) && (operator_info != nullptr)) {
auto cost = operator_info->GetForwardMemoryCostFromCNode();
MS_LOG(DEBUG) << "cnode " << cnode->DebugString() << " cost: " << cost;
if (allreduce_graph_.NodeInGraph(cnode)) {

View File

@@ -83,7 +83,7 @@ Status AllreduceNode::AddPara(const AnfNodePtr &node_ptr) {
}
auto para_ptr = node_ptr->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(para_ptr);
auto layout_ptr = para_ptr->tensor_layout();
auto layout_ptr = para_ptr->GetUserData<TensorLayout>();
if (layout_ptr == nullptr) {
MS_LOG(ERROR) << "layout_ptr is nullptr!";
return FAILED;

View File

@@ -37,7 +37,7 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
for (auto para : graph_params) {
std::string name = std::static_pointer_cast<Parameter>(para)->name();
std::shared_ptr<parallel::TensorLayout> tensor_layout = std::static_pointer_cast<Parameter>(para)->tensor_layout();
auto tensor_layout = para->GetUserData<parallel::TensorLayout>();
if (tensor_layout == nullptr) {
MS_LOG(INFO) << "GetParameterLayout nullptr name = " << name;
} else {
@@ -70,7 +70,7 @@ py::dict GetCNodeStrategy(const FuncGraphPtr &graph) {
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto distributed_operation_info = cnode->operator_info();
auto distributed_operation_info = cnode->GetUserData<OperatorInfo>();
if (distributed_operation_info != nullptr) {
auto strategyPtr = distributed_operation_info->strategy();
if (strategyPtr != nullptr) {

View File

@@ -163,6 +163,9 @@ class OperatorInfo {
const std::string &type() const { return type_; }
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
// Key for user data.
constexpr static char key[] = "OpInfo";
protected:
// needed by rec_parser
std::string type_;
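
The key member added here hints at how the decoupling works: judging by the SetUserData<OperatorInfo> / GetUserData<OperatorInfo> / HasUserData<OperatorInfo> calls throughout this diff, a node keeps a generic, string-keyed user-data map, the key is supplied by the stored type itself, and a missing entry reads back as nullptr, so ir/ code never has to know the frontend type. A minimal sketch of such a store under those assumptions; UserDataHolder and the simplified OperatorInfo below are illustrative stand-ins, not the real MindSpore classes:

// Illustrative sketch only (simplified stand-ins, C++17); not part of this diff.
#include <map>
#include <memory>
#include <string>

// A node-side store keyed by T::key, mirroring SetUserData/GetUserData/HasUserData.
class UserDataHolder {
 public:
  template <typename T>
  void SetUserData(const std::shared_ptr<T> &value) {
    data_[T::key] = value;
  }
  template <typename T>
  std::shared_ptr<T> GetUserData() const {
    auto it = data_.find(T::key);
    return it == data_.end() ? nullptr : std::static_pointer_cast<T>(it->second);
  }
  template <typename T>
  bool HasUserData() const {
    return data_.count(T::key) != 0;
  }

 private:
  std::map<std::string, std::shared_ptr<void>> data_;
};

// Frontend-side payload: it carries its own key, so the IR never includes its header.
struct OperatorInfo {
  constexpr static char key[] = "OpInfo";
  std::string name;
};

int main() {
  UserDataHolder node;  // stands in for a CNode or Parameter
  auto info = std::make_shared<OperatorInfo>();
  info->name = "MatMulInfo";
  node.SetUserData(info);
  bool ok = node.HasUserData<OperatorInfo>() && node.GetUserData<OperatorInfo>()->name == "MatMulInfo";
  return ok ? 0 : 1;
}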

View File

@@ -435,7 +435,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
entire_costgraph->AddOperator(operator_info);
(void)cnode->set_operator_info(operator_info);
cnode->SetUserData<OperatorInfo>(operator_info);
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
<< " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
@@ -501,7 +501,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
entire_costgraph->AddOperator(operator_info);
(void)cnode->set_operator_info(operator_info);
cnode->SetUserData<OperatorInfo>(operator_info);
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
<< " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
@@ -520,7 +520,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
<< " does not match the Prim: " << prim->name();
}
(void)cnode->set_operator_info(current_op_ptr);
cnode->SetUserData<OperatorInfo>(current_op_ptr);
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
<< " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
@@ -549,6 +549,8 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
size_t edge_count = 0;
auto node_op_info = cnode->GetUserData<OperatorInfo>();
for (size_t i = 1; i < inputs.size(); ++i) {
auto prev_cnode = inputs[i]->cast<CNodePtr>();
bool bool_result_prev_cnode = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
@@ -563,8 +565,8 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
(IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND);
while (bool_result) {
if (IsAutoParallelCareNode(prev_cnode)) {
std::string edge_name =
prev_cnode->operator_info()->name() + OPERATOR_TO_OPERATOR_CONNECTOR + cnode->operator_info()->name();
auto prev_op_info = prev_cnode->GetUserData<OperatorInfo>();
std::string edge_name = prev_op_info->name() + OPERATOR_TO_OPERATOR_CONNECTOR + node_op_info->name();
// If the edge between these two operators already has been added, then the edge will not be added again.
if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, i - 1)) {
break;
@@ -577,22 +579,20 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
if (follow_strategy) {
// Redistribution in not allowed on the edge.
// Elementwise operators have the same strategy as their previous operators.
edge_ptr = std::make_shared<Edge>(edge_name, prev_cnode->operator_info(), cnode->operator_info(),
output_index, i - 1, false, true);
edge_ptr = std::make_shared<Edge>(edge_name, prev_op_info, node_op_info, output_index, i - 1, false, true);
} else {
edge_ptr = std::make_shared<Edge>(edge_name, prev_cnode->operator_info(), cnode->operator_info(),
output_index, i - 1, false);
edge_ptr = std::make_shared<Edge>(edge_name, prev_op_info, node_op_info, output_index, i - 1, false);
}
// Init costs for this edge
if (edge_ptr->InitEdgeCost() != SUCCESS) {
MS_LOG(EXCEPTION) << "Edge cost initialization failed";
}
cnode->operator_info()->AddPrevEdge(edge_ptr);
prev_cnode->operator_info()->AddSuccEdge(edge_ptr);
entire_costgraph->AddEdge(prev_cnode->operator_info(), cnode->operator_info(), edge_ptr);
MS_LOG(INFO) << "Successfully adding the edge between " << prev_cnode->operator_info()->name() << " and "
<< cnode->operator_info()->name();
node_op_info->AddPrevEdge(edge_ptr);
prev_op_info->AddSuccEdge(edge_ptr);
entire_costgraph->AddEdge(prev_op_info, node_op_info, edge_ptr);
MS_LOG(INFO) << "Successfully adding the edge between " << prev_op_info->name() << " and "
<< node_op_info->name();
edge_count++;
break;
@@ -633,7 +633,7 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
(IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND);
}
}
MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << cnode->operator_info()->name();
MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << node_op_info->name();
}
MS_LOG(INFO) << "Constructing edges for cost graph ends.";
@@ -750,7 +750,8 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
for (auto &target : target_set) {
auto target_cnode = target.first->cast<CNodePtr>();
auto input_index = target.second;
(void)target_without_duplicate.insert(std::to_string(input_index) + target_cnode->operator_info()->name());
(void)target_without_duplicate.insert(std::to_string(input_index) +
target_cnode->GetUserData<OperatorInfo>()->name());
}
if (target_without_duplicate.size() <= 1) {
continue;
@@ -830,24 +831,24 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
auto target_cnode = target.first->cast<CNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(target_cnode->input(0));
auto input_index = target.second;
auto target_op_info = target_cnode->GetUserData<OperatorInfo>();
std::string edge_name =
std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_cnode->operator_info()->name();
std::string edge_name = std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_op_info->name();
// If the edge between these two operators already has been added, then the edge will not be added again.
if (entire_costgraph->IsEdgeInCostGraph(edge_name, 0, IntToSize(input_index - 1))) {
continue;
}
std::shared_ptr<Edge> edge_ptr = std::make_shared<Edge>(
edge_name, tmp_identity_ptr, target_cnode->operator_info(), 0, input_index - 1, false, true);
std::shared_ptr<Edge> edge_ptr =
std::make_shared<Edge>(edge_name, tmp_identity_ptr, target_op_info, 0, input_index - 1, false, true);
if (edge_ptr->InitEdgeCost() != SUCCESS) {
MS_LOG(EXCEPTION) << "Edge cost initialization failed";
}
target_cnode->operator_info()->AddPrevEdge(edge_ptr);
target_op_info->AddPrevEdge(edge_ptr);
tmp_identity_ptr->AddSuccEdge(edge_ptr);
entire_costgraph->AddEdge(tmp_identity_ptr, target_cnode->operator_info(), edge_ptr);
entire_costgraph->AddEdge(tmp_identity_ptr, target_op_info, edge_ptr);
MS_LOG(INFO) << "Successfully adding the edge between " << tmp_identity_ptr->name() << " and "
<< target_cnode->operator_info()->name();
<< target_op_info->name();
add_identity_edge = true;
}
if (new_identity && add_identity_edge) {
@@ -861,20 +862,13 @@ bool FindReshape(const CNodePtr &cnode) {
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
return false;
}
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) {
if (!IsParallelCareNode(cnode) || !cnode->HasUserData<OperatorInfo>()) {
return false;
}
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
MS_EXCEPTION_IF_NULL(prim);
OperatorInfoPtr operator_info = cnode->operator_info();
if (operator_info == nullptr) {
MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr";
}
if (prim->name() != RESHAPE) {
return false;
}
return true;
return (prim->name() == RESHAPE);
}
// find previous node, then obtain its strategy_cost_ vector to get its layout vector.
@@ -890,8 +884,9 @@ bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_
if (!IsValueNode<Primitive>(cnode->input(0))) {
return false;
}
if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) {
*pre_operator_info = cnode->operator_info();
auto node_op_info = cnode->GetUserData<OperatorInfo>();
if (IsParallelCareNode(cnode) && (node_op_info != nullptr)) {
*pre_operator_info = node_op_info;
*out_index = 0;
return true;
}
@@ -905,8 +900,9 @@ bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_
MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode";
}
CNodePtr pre_cnode = pre_node->cast<CNodePtr>();
if (IsParallelCareNode(pre_cnode) && (pre_cnode->operator_info() != nullptr)) {
*pre_operator_info = pre_cnode->operator_info();
auto pre_op_info = pre_cnode->GetUserData<OperatorInfo>();
if (IsParallelCareNode(pre_cnode) && (pre_op_info != nullptr)) {
*pre_operator_info = pre_op_info;
return true;
}
return false;
@@ -945,14 +941,15 @@ bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator
if (node_prim->name() == DEPEND && node_pair.second != 1) {
continue;
}
if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) {
auto op_info = use_apply->GetUserData<OperatorInfo>();
if (IsParallelCareNode(use_apply) && (op_info != nullptr)) {
MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name();
*next_operator_info = use_apply->operator_info();
*next_operator_info = op_info;
*in_index = node_pair.second - 1;
return true;
}
MS_LOG(DEBUG) << "FindNextNodeStraCosts failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
<< " " << (use_apply->operator_info() != nullptr);
<< " " << (op_info != nullptr);
if (FindNextNodeStraCosts(use_apply, next_operator_info, in_index)) {
return true;
@ -973,8 +970,8 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
int32_t out_index = 0;
OperatorInfoPtr pre_operator_info;
std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs;
auto operator_info = cnode->GetUserData<OperatorInfo>();
if (pre_node->isa<Parameter>()) {
OperatorInfoPtr operator_info = cnode->operator_info();
auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
reshape_info->SetCostForReshapeWithParameter();
pre_operator_info = reshape_info;
@ -995,7 +992,6 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
}
// set input_layout and output_layout for reshape.
// init reshape and set cost for each input_layout and output_layout.
OperatorInfoPtr operator_info = cnode->operator_info();
auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
reshape_info->set_pre_operator_name(pre_operator_info->name());
reshape_info->set_pre_operator_index(out_index);

View File

@ -272,7 +272,7 @@ OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) {
if (!IsParallelCareNode(node)) {
return nullptr;
}
OperatorInfoPtr distribute_operator = node->operator_info();
OperatorInfoPtr distribute_operator = node->GetUserData<OperatorInfo>();
if (distribute_operator == nullptr) {
MS_LOG(EXCEPTION) << "GetDistributeOperator:distribute_operator is nullptr";
}
@ -415,7 +415,7 @@ bool IsParallelCareNode(const CNodePtr &cnode) {
if (prim->name() == GET_NEXT) {
return true;
}
if ((prim->name() == CAST) && (cnode->operator_info() == nullptr)) {
if ((prim->name() == CAST) && !cnode->HasUserData<OperatorInfo>()) {
return false;
}
@ -452,7 +452,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
if (node_prim->name() == DEPEND && node_pair.second != 1) {
continue;
}
if (IsParallelCareNode(use_cnode) && (use_cnode->operator_info() != nullptr)) {
if (IsParallelCareNode(use_cnode) && use_cnode->HasUserData<OperatorInfo>()) {
Redistribution(node_pair, distribute_operator, insert_node_new, node_pair.second, tensor_redistribution,
pre_node);
} else {
@ -465,7 +465,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int index) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(next_node);
OperatorInfoPtr op_info = next_node->operator_info();
OperatorInfoPtr op_info = next_node->GetUserData<OperatorInfo>();
MS_EXCEPTION_IF_NULL(op_info);
// If the shape of tensor is [] or [1], no need to split it.
@ -590,7 +590,7 @@ void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) {
void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
// step1:get graph manager distribute_operator
OperatorInfoPtr distribute_operator = node->operator_info();
OperatorInfoPtr distribute_operator = node->GetUserData<OperatorInfo>();
if (distribute_operator == nullptr) {
MS_LOG(EXCEPTION) << "Failure:AddNode error since distribute_operator is nullptr";
}
@ -628,7 +628,7 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
(void)prim->SetAttrs(attrs);
}
if (index == replace_op.size() - 1) {
(void)replace_node->set_operator_info(node->operator_info());
replace_node->SetUserData<OperatorInfo>(node->GetUserData<OperatorInfo>());
}
replace_node->set_in_forward_flag(true);
replace_input[0]->set_scope(scope);
@ -708,7 +708,7 @@ LossNodeInfo GetLossNodeInfo(const AnfNodePtr &loss_node) {
auto pre_cnode = pre_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(pre_cnode);
auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
if (pre_prim->name() == CAST && pre_cnode->operator_info() == nullptr) {
if (pre_prim->name() == CAST && !pre_cnode->HasUserData<OperatorInfo>()) {
pre_node = pre_cnode->input(1);
}
@ -1204,7 +1204,7 @@ std::pair<AnfNodePtr, int> FindParallelCareNode(const AnfNodePtr &node) {
if (node_prim->name() == DEPEND && node_pair.second != 1) {
continue;
}
if (IsParallelCareNode(cnode) && cnode->operator_info() != nullptr) {
if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) {
return node_pair;
} else if (FindParallelCareNode(node_pair.first).first != nullptr) {
return FindParallelCareNode(node_pair.first);
@ -1254,7 +1254,7 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, i
MS_LOG(DEBUG) << "SetParallelShape " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
CNodePtr cnode = res.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
OperatorInfoPtr distribute_operator = cnode->operator_info();
OperatorInfoPtr distribute_operator = cnode->GetUserData<OperatorInfo>();
if (distribute_operator == nullptr) {
MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr";
}
@ -1277,7 +1277,7 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, i
TensorLayout tensor_layout = tensorinfo_in.tensor_layout();
ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(parameter_ptr);
parameter_ptr->set_tensor_layout(std::make_shared<TensorLayout>(tensor_layout));
parameter_ptr->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
}
void CoverSliceShape(const FuncGraphPtr &root) {
@ -1365,7 +1365,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
if (found_be_cloned_parameter) {
// set the shape and tensor layout for cloned parameter
cloned_parameter->set_tensor_layout(cloned_from_parameter->tensor_layout());
cloned_parameter->SetUserData<TensorLayout>(cloned_from_parameter->GetUserData<TensorLayout>());
MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
auto cloned_abstract = cloned_parameter_node->abstract()->Clone();
@ -1464,7 +1464,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
(*operator_).set_outputs_dtype(cnode->Type());
(*operator_).set_cnode(cnode);
if (prim->name() == RESHAPE) {
(void)cnode->set_operator_info(operator_);
cnode->SetUserData<OperatorInfo>(operator_);
continue;
}
// load strategy checkpoint
@ -1499,7 +1499,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
if (operator_->Init(strategyPtr) == FAILED) {
MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed";
}
(void)cnode->set_operator_info(operator_);
cnode->SetUserData<OperatorInfo>(operator_);
} else {
MS_LOG(EXCEPTION) << "ERROR:strategy_ptr is nullptr";
}
@ -1542,13 +1542,13 @@ std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode) {
if (node_prim->name() == DEPEND && node_pair.second != 1) {
continue;
}
if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) {
if (IsParallelCareNode(use_apply) && use_apply->HasUserData<OperatorInfo>()) {
MS_LOG(INFO) << "FindNextLayout success prim " << node_prim->name();
auto layout = GetInputLayoutFromCNode(node_pair);
return std::make_shared<TensorLayout>(layout);
}
MS_LOG(DEBUG) << "FindNextLayout failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
<< " " << (use_apply->operator_info() != nullptr);
<< " " << use_apply->HasUserData<OperatorInfo>();
auto layout_ptr = FindNextLayout(use_apply);
if (layout_ptr) {
@ -1580,7 +1580,7 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n
if (!IsValueNode<Primitive>(cnode->input(0))) {
return nullptr;
}
if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) {
if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) {
auto layout_ptr = GetOutputLayoutFromCNode(cnode, output_index);
if (!layout_ptr) {
MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed";
@ -1624,7 +1624,7 @@ std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) {
if (!IsValueNode<Primitive>(cnode->input(0))) {
return nullptr;
}
if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) {
if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) {
auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0);
if (!layout_ptr) {
MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed";
@ -1664,12 +1664,12 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) {
continue;
}
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) {
if (!IsParallelCareNode(cnode) || !cnode->HasUserData<OperatorInfo>()) {
continue;
}
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
MS_EXCEPTION_IF_NULL(prim);
OperatorInfoPtr operator_info = cnode->operator_info();
OperatorInfoPtr operator_info = cnode->GetUserData<OperatorInfo>();
if (operator_info == nullptr) {
MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr";
}
@ -1714,7 +1714,7 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
// return -> cast
if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) {
if (current_prim->name() == CAST && !pre_cnode->HasUserData<OperatorInfo>()) {
pre_cnode = pre_cnode->input(1)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(pre_cnode);
current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
@ -1771,7 +1771,7 @@ TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) {
return ret;
}
OperatorInfoPtr operator_info = loss_cnode->operator_info();
OperatorInfoPtr operator_info = loss_cnode->GetUserData<OperatorInfo>();
MS_EXCEPTION_IF_NULL(operator_info);
TensorInfo loss_grad_tensor_info;
size_t op_output_size = operator_info->outputs_tensor_info().size();
@ -1809,7 +1809,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay
if (sens_tensor_node->isa<Parameter>()) {
auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString();
sens_tensor_param->set_tensor_layout(std::make_shared<TensorLayout>(loss_grad_layout));
sens_tensor_param->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
}
MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", no need to split sens";
return;
@ -1834,7 +1834,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay
cloned_abstract->set_shape(parallel_shape);
sens_tensor_node->set_abstract(cloned_abstract);
auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
sens_tensor_param->set_tensor_layout(std::make_shared<TensorLayout>(loss_grad_layout));
sens_tensor_param->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
return;
}
MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter, it is unsupported now.";
@ -2125,7 +2125,7 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
}
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_EXCEPTION_IF_NULL(prim);
OperatorInfoPtr operator_info = cnode->operator_info();
OperatorInfoPtr operator_info = cnode->GetUserData<OperatorInfo>();
if (operator_info) {
if (operator_info->name().find(RESHAPEINFO) != std::string::npos) {
continue;

View File

@ -83,6 +83,9 @@ class TensorLayout {
TensorLayout SqueezeShape() const;
// Key for user data.
constexpr static char key[] = "TLayout";
private:
std::shared_ptr<TensorLayout> ExpandTensorShapeWithoutExtendDeviceArrangement(
const Arrangement &expanded_shape) const;
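With the public `key` member in place, a TensorLayout can be attached to any node through the generic user-data accessors, mirroring the SetUserData&lt;TensorLayout&gt; calls in step_parallel.cc above. A minimal sketch of the new pattern, assuming the call site already sits inside namespace mindspore::parallel (the parameter_ptr and tensor_layout variables are illustrative, not part of this change):
// Attach a layout; it is stored under TensorLayout::key ("TLayout").
parameter_ptr->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
// Read it back later; GetUserData returns nullptr when nothing was attached.
auto stored_layout = parameter_ptr->GetUserData<TensorLayout>();
if (stored_layout != nullptr) {
  MS_LOG(INFO) << "parameter layout: " << stored_layout->ToString();
}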

mindspore/core/base/core_ops.h (Executable file, 160 additions)
View File

@ -0,0 +1,160 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPERATOR_OPS_H_
#define MINDSPORE_CORE_OPERATOR_OPS_H_
#include <iostream>
#include <string>
#include <memory>
#include "ir/anf.h"
#include "ir/primitive.h"
namespace mindspore {
namespace prim {
// Maths
inline const PrimitivePtr kPrimTensorAdd = std::make_shared<Primitive>("TensorAdd");
inline const PrimitivePtr kPrimMatMul = std::make_shared<Primitive>("MatMul");
inline const PrimitivePtr kPrimBatchMatMul = std::make_shared<Primitive>("BatchMatMul");
inline const PrimitivePtr kPrimMaximumGrad = std::make_shared<Primitive>("MaximumGrad");
inline const PrimitivePtr kPrimMinimumGrad = std::make_shared<Primitive>("MinimumGrad");
inline const PrimitivePtr kPrimReduceMean = std::make_shared<Primitive>("ReduceMean");
inline const PrimitivePtr kPrimReduceSum = std::make_shared<Primitive>("ReduceSum");
inline const PrimitivePtr kPrimReduceAll = std::make_shared<Primitive>("ReduceAll");
inline const PrimitivePtr kPrimReduceMax = std::make_shared<Primitive>("ReduceMax");
inline const PrimitivePtr kPrimReduceMin = std::make_shared<Primitive>("ReduceMin");
inline const PrimitivePtr kPrimNeg = std::make_shared<Primitive>("Neg");
inline const PrimitivePtr kPrimSub = std::make_shared<Primitive>("Sub");
inline const PrimitivePtr kPrimMul = std::make_shared<Primitive>("Mul");
inline const PrimitivePtr kPrimMinimum = std::make_shared<Primitive>("Minimum");
inline const PrimitivePtr kPrimMaximum = std::make_shared<Primitive>("Maximum");
inline const PrimitivePtr kPrimSquare = std::make_shared<Primitive>("Square");
inline const PrimitivePtr kPrimCumSum = std::make_shared<Primitive>("CumSum");
inline const PrimitivePtr kPrimCumProd = std::make_shared<Primitive>("CumProd");
inline const PrimitivePtr kPrimSubscalar = std::make_shared<Primitive>("Subscalar");
inline const PrimitivePtr kPrimInplaceAdd = std::make_shared<Primitive>("InplaceAdd");
inline const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub");
inline const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow");
inline const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>("RealDiv");
inline const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt");
inline const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>("Reciprocal");
inline const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandDims");
// Statements
inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return");
inline const PrimitivePtr kPrimSwitch = std::make_shared<Primitive>("switch");
inline const PrimitivePtr kPrimSwitchLayer = std::make_shared<Primitive>("switch_layer");
inline const PrimitivePtr kPrimAssign = std::make_shared<Primitive>("Assign");
inline const PrimitivePtr kPrimAssignAdd = std::make_shared<Primitive>("AssignAdd");
inline const PrimitivePtr kPrimAssignSub = std::make_shared<Primitive>("AssignSub");
inline const PrimitivePtr kPrimSelect = std::make_shared<Primitive>("Select");
inline const PrimitivePtr kPrimCall = std::make_shared<Primitive>("call");
// Structures
inline const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal");
inline const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat");
inline const PrimitivePtr kPrimMakeTuple = std::make_shared<Primitive>("make_tuple");
inline const PrimitivePtr kPrimMakeDict = std::make_shared<Primitive>("make_dict");
inline const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list");
inline const PrimitivePtr kPrimMakeKeywordArg = std::make_shared<Primitive>("make_keyword_arg");
inline const PrimitivePtr kPrimMakeSlice = std::make_shared<Primitive>("make_slice");
inline const PrimitivePtr kPrimMakeRecord = std::make_shared<Primitive>("make_record");
inline const PrimitivePtr kPrimTupleGetItem = std::make_shared<Primitive>("tuple_getitem");
inline const PrimitivePtr kPrimListGetItem = std::make_shared<Primitive>("list_getitem");
inline const PrimitivePtr kPrimArrayGetItem = std::make_shared<Primitive>("array_getitem");
inline const PrimitivePtr kPrimTupleSetItem = std::make_shared<Primitive>("tuple_setitem");
inline const PrimitivePtr kPrimListSetItem = std::make_shared<Primitive>("list_setitem");
inline const PrimitivePtr kPrimArraySetItem = std::make_shared<Primitive>("array_setitem");
inline const PrimitivePtr kPrimDictGetItem = std::make_shared<Primitive>("dict_getitem");
inline const PrimitivePtr kPrimDictSetItem = std::make_shared<Primitive>("dict_setitem");
inline const PrimitivePtr kPrimListAppend = std::make_shared<Primitive>("list_append");
inline const PrimitivePtr kPrimGetAttr = std::make_shared<Primitive>("getattr");
inline const PrimitivePtr kPrimTupleLen = std::make_shared<Primitive>("tuple_len");
inline const PrimitivePtr kPrimDictLen = std::make_shared<Primitive>("dict_len");
inline const PrimitivePtr kPrimListLen = std::make_shared<Primitive>("list_len");
inline const PrimitivePtr kPrimArrayLen = std::make_shared<Primitive>("array_len");
inline const PrimitivePtr kPrimListMap = std::make_shared<Primitive>("list_map");
inline const PrimitivePtr kPrimListReduce = std::make_shared<Primitive>("list_reduce");
inline const PrimitivePtr kPrimTupleReversed = std::make_shared<Primitive>("tuple_reversed");
inline const PrimitivePtr kPrimTileShape = std::make_shared<Primitive>("tile_shape");
inline const PrimitivePtr kPrimReducedShape = std::make_shared<Primitive>("reduced_shape");
inline const PrimitivePtr kPrimTupleDiv = std::make_shared<Primitive>("tuple_div");
inline const PrimitivePtr kPrimTupleToArray = std::make_shared<Primitive>("tuple_to_array");
inline const PrimitivePtr kPrimShapeMul = std::make_shared<Primitive>("shape_mul");
inline const PrimitivePtr kPrimGenerateShapeIndex = std::make_shared<Primitive>("generate_shape_index");
inline const PrimitivePtr kPrimGenerateInverseIndex = std::make_shared<Primitive>("generate_inverse_index");
inline const PrimitivePtr kPrimTupleEqual = std::make_shared<Primitive>("tuple_equal");
inline const PrimitivePtr kPrimListEqual = std::make_shared<Primitive>("list_equal");
inline const PrimitivePtr kPrimMakeRange = std::make_shared<Primitive>("make_range");
inline const PrimitivePtr kPrimStopGradient = std::make_shared<Primitive>("stop_gradient");
inline const PrimitivePtr kPrimExtractKeywordArg = std::make_shared<Primitive>("extract_keyword_arg");
// Debug ops
inline const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary");
inline const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary");
inline const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary");
inline const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary");
inline const PrimitivePtr kPrimDebug = std::make_shared<Primitive>("Debug");
// Other miscellaneous
inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J");
inline const PrimitivePtr kPrimDepend = std::make_shared<Primitive>("Depend");
inline const PrimitivePtr kPrimPartial = std::make_shared<Primitive>("Partial");
inline const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity");
inline const PrimitivePtr kPrimEnvSetItem = std::make_shared<Primitive>("env_setitem");
inline const PrimitivePtr kPrimEnvGetItem = std::make_shared<Primitive>("env_getitem");
inline const PrimitivePtr kPrimEnvAdd = std::make_shared<Primitive>("env_add");
inline const PrimitivePtr kPrimMakeRefKey = std::make_shared<Primitive>("MakeRefKey");
inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key");
inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin");
inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
inline const PrimitivePtr kPrimHookBackward = std::make_shared<Primitive>("HookBackward");
inline const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType");
inline const PrimitivePtr kPrimSameTypeShape = std::make_shared<Primitive>("SameTypeShape");
inline const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop");
inline const PrimitivePtr kPrimPrint = std::make_shared<Primitive>("Print");
inline const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");
inline const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem");
inline const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("BroadcastGradientArgs");
inline const PrimitivePtr kPrimControlDepend = std::make_shared<Primitive>("ControlDepend");
inline const PrimitivePtr kPrimIs_ = std::make_shared<Primitive>("is_");
inline const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not");
inline const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict");
inline const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
inline const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast");
inline const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant");
inline const PrimitivePtr kPrimEquivFormat = std::make_shared<Primitive>("EquivFormat");
class DoSignaturePrimitive : public Primitive {
public:
explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function)
: Primitive("S-Prim-" + name), function_(function) {}
~DoSignaturePrimitive() override = default;
MS_DECLARE_PARENT(DoSignaturePrimitive, Primitive)
const ValuePtr function() const { return function_; }
private:
ValuePtr function_;
};
using DoSignaturePrimitivePtr = std::shared_ptr<DoSignaturePrimitive>;
} // namespace prim
} // namespace mindspore
#endif // MINDSPORE_CORE_OPERATOR_OPS_H_
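Since these primitives now live in the core layer as inline variables, frontend and optimizer code switches its include from "frontend/operator/ops.h" to "base/core_ops.h", as the include-line changes below show. A hedged sketch of a typical consumer; the cnode variable is illustrative and not part of this change:
#include "base/core_ops.h"
// Check whether a CNode calls tuple_getitem via the relocated primitive constant.
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (prim != nullptr && prim->name() == prim::kPrimTupleGetItem->name()) {
  // handle the tuple_getitem case
}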

View File

@ -0,0 +1,52 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_USER_DATA_H_
#define MINDSPORE_CORE_USER_DATA_H_
#include <string>
#include <memory>
#include <map>
namespace mindspore {
class UserData {
public:
template <typename T>
void set(const std::string &key, const std::shared_ptr<T> &value) {
if (value == nullptr) {
data_.erase(key);
} else {
data_.insert_or_assign(key, value);
}
}
template <typename T>
std::shared_ptr<T> get(const std::string &key) const {
auto iter = data_.find(key);
if (iter == data_.end()) {
return nullptr;
}
return std::static_pointer_cast<T>(iter->second);
}
bool has(const std::string &key) const { return data_.find(key) != data_.end(); }
private:
std::map<std::string, std::shared_ptr<void>> data_;
};
} // namespace mindspore
#endif // MINDSPORE_CORE_USER_DATA_H_
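A minimal, self-contained usage sketch for this container; the MyInfo type and its key are hypothetical and only illustrate the T::key convention that TensorLayout follows above:
#include <memory>
#include "base/user_data.h"
struct MyInfo {
  constexpr static char key[] = "MyInfo";  // same convention as TensorLayout::key
  int value = 0;
};
void UserDataSketch() {
  mindspore::UserData user_data;
  user_data.set<MyInfo>(MyInfo::key, std::make_shared<MyInfo>());
  auto info = user_data.get<MyInfo>(MyInfo::key);  // nullptr when the key is absent
  bool present = user_data.has(MyInfo::key);       // true after the set above
  user_data.set<MyInfo>(MyInfo::key, nullptr);     // setting nullptr erases the entry
  (void)info;
  (void)present;
}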

View File

@ -26,7 +26,6 @@
#include "ir/func_graph.h"
#include "ir/primitive.h"
#include "utils/context/ms_context.h"
#include "frontend/operator/ops.h"
namespace mindspore {
// namespace to support intermediate representation definition

View File

@ -27,6 +27,7 @@
#include <utility>
#include "base/base.h"
#include "base/user_data.h"
#include "ir/kernel_info_dev.h"
#include "ir/scope.h"
#include "debug/info.h"
@ -41,12 +42,6 @@
// ANode: Atomic Node
// CNode: Complex Node
namespace mindspore {
namespace parallel {
class TensorLayout;
class OperatorInfo;
} // namespace parallel
using OperatorInfoPtr = std::shared_ptr<parallel::OperatorInfo>;
namespace abstract {
class BaseShape;
class AbstractBase;
@ -157,6 +152,31 @@ class AnfNode : public Base {
}
size_t seen_{0};
template <typename T>
void SetUserData(const std::string &key, const std::shared_ptr<T> &value) {
user_data_.set<T>(key, value);
}
template <typename T>
void SetUserData(const std::shared_ptr<T> &value) {
user_data_.set<T>(T::key, value);
}
template <typename T>
std::shared_ptr<T> GetUserData(const std::string &key) const {
return user_data_.get<T>(key);
}
template <typename T>
std::shared_ptr<T> GetUserData() const {
return user_data_.get<T>(T::key);
}
bool HasUserData(const std::string &key) const { return user_data_.has(key); }
template <typename T>
bool HasUserData() const { return user_data_.has(T::key); }
protected:
// Hold a weak ref to Graph as Graph also hold ref to AnfNode.
// Otherwise, func_graph_ and AnfNode will make a reference cycle.
@ -170,6 +190,7 @@ class AnfNode : public Base {
std::hash<const AnfNode *> hash_;
ScopePtr scope_;
KernelInfoDevicePtr kernel_info_;
UserData user_data_;
};
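These templated accessors are what actually decouple the IR from the frontend: AnfNode now holds only the type-erased UserData map, and the keyless overloads resolve the storage key from T::key at the call site, so ir/anf.h no longer needs to forward-declare parallel::OperatorInfo or parallel::TensorLayout. A sketch of the resulting call pattern, assuming parallel::OperatorInfo also defines a static key member analogous to TensorLayout::key (the operator_info variable is illustrative):
// Frontend side: attach and query an OperatorInfo without the IR knowing the concrete type.
cnode->SetUserData<parallel::OperatorInfo>(operator_info);      // replaces set_operator_info()
if (cnode->HasUserData<parallel::OperatorInfo>()) {
  auto op_info = cnode->GetUserData<parallel::OperatorInfo>();  // replaces operator_info()
  MS_LOG(INFO) << "operator: " << op_info->name();
}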
// CNode represents the complex node with a set of arguments.
@ -212,9 +233,6 @@ class CNode : public AnfNode {
std::string DebugString(int recursive_level = 1) const override;
std::string DebugString(bool recursive) const override { return DebugString(recursive ? 1 : 0); }
OperatorInfoPtr set_operator_info(const OperatorInfoPtr &operator_info);
OperatorInfoPtr operator_info() { return operator_info_; }
void set_in_forward_flag(bool flag) { in_forward_flag_ = flag; }
bool in_forward_flag() const { return in_forward_flag_; }
@ -224,7 +242,6 @@ class CNode : public AnfNode {
std::vector<AnfNodePtr> inputs_;
VarPtr func_graph_as_var_;
bool stop_gradient_;
OperatorInfoPtr operator_info_ = nullptr;
bool in_forward_flag_ = false;
};
@ -244,7 +261,7 @@ class ANode : public AnfNode {
class Parameter : public ANode {
public:
explicit Parameter(const FuncGraphPtr &func_graph)
: ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr), tensor_layout_(nullptr) {}
: ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr) {}
~Parameter() override = default;
MS_DECLARE_PARENT(Parameter, ANode);
@ -261,11 +278,6 @@ class Parameter : public ANode {
}
ParamValuePtr default_param() const { return default_param_; }
std::shared_ptr<parallel::TensorLayout> tensor_layout() const { return tensor_layout_; }
void set_tensor_layout(const std::shared_ptr<parallel::TensorLayout> &tensor_layout) {
tensor_layout_ = tensor_layout;
}
bool operator==(const AnfNode &other) const override {
if (!other.isa<Parameter>()) {
return false;
@ -281,7 +293,6 @@ class Parameter : public ANode {
std::string name_;
bool has_default_;
ParamValuePtr default_param_;
std::shared_ptr<parallel::TensorLayout> tensor_layout_;
};
using ParameterPtr = std::shared_ptr<Parameter>;

View File

@ -23,8 +23,7 @@
#include "ir/visitor.h"
#include "ir/func_graph.h"
#include "frontend/operator/ops.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "base/core_ops.h"
#include "debug/label.h"
namespace mindspore {
@ -37,18 +36,6 @@ std::string AnfNode::ToString() const {
return mindspore::label_manage::Label(const_cast<AnfNode *>(this)->shared_from_base<AnfNode>()->debug_info());
}
OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) {
if (operator_info_ != nullptr) {
MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name()
<< ", using the new one: " << operator_info->name();
auto old_ptr = operator_info_;
operator_info_ = operator_info;
return old_ptr;
}
operator_info_ = operator_info;
return nullptr;
}
std::string CNode::fullname_with_scope() {
// if full name is set, return its name immediately
if (!fullname_with_scope_.empty()) {

View File

@ -24,7 +24,6 @@
#include "debug/trace.h"
#include "ir/manager.h"
#include "frontend/operator/ops.h"
#include "utils/ordered_set.h"
#include "utils/convert_utils_base.h"

View File

@ -20,7 +20,7 @@
#include "ir/manager.h"
#include "ir/param_value.h"
#include "frontend/operator/ops.h"
#include "base/core_ops.h"
#include "utils/convert_utils_base.h"
#include "utils/log_adapter.h"
#include "utils/profile.h"

View File

@ -22,7 +22,7 @@
#include "ir/manager.h"
#include "ir/func_graph_cloner.h"
#include "frontend/operator/ops.h"
#include "base/core_ops.h"
#include "utils/ordered_set.h"
#include "abstract/abstract_value.h"
#include "debug/anf_ir_dump.h"

View File

@ -26,7 +26,7 @@
#include "ir/func_graph.h"
#include "utils/profile.h"
#include "utils/convert_utils_base.h"
#include "frontend/operator/ops.h"
#include "base/core_ops.h"
namespace mindspore {

View File

@ -17,10 +17,8 @@
*/
#include "ir/meta_func_graph.h"
#include "pipeline/jit/static_analysis/static_analysis.h"
#include "pipeline/jit/static_analysis/abstract_function.h"
#include "base/core_ops.h"
#include "utils/context/ms_context.h"
#include "frontend/operator/ops.h"
// namespace to support intermediate representation definition
namespace mindspore {

View File

@ -22,9 +22,9 @@
#include <tuple>
#include <vector>
#include "frontend/operator/ops.h"
#include "frontend/optimizer/optimizer.h"
#include "ir/anf.h"
#include "ir/optimizer_caller.h"
#include "base/core_ops.h"
namespace mindspore {
///

View File

@ -25,7 +25,6 @@
#include "ir/dtype/type.h"
#include "abstract/abstract_value.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "utils/base_ref_extends.h"
namespace mindspore {

View File

@ -18,7 +18,6 @@
#include <mutex>
#include <utility>
#include "ir/signature.h"
#include "frontend/operator/ops.h"
#include "./common.h"
#include "pipeline/jit/parse/python_adapter.h"
#include "pipeline/jit/parse/data_converter.h"

View File

@ -28,7 +28,6 @@
#include <type_traits>
#include <typeinfo>
#include "runtime/device/device_address.h"
#include "abstract/abstract_value.h"
namespace mindspore {

View File

@ -153,7 +153,7 @@ TEST_F(TestStepAutoParallel, test_create_op_instance) {
StrategyPtr strategyPtr;
std::shared_ptr<OperatorInfo> matmul_info = NewOperatorInstance(prim, attrs, shape);
node->set_operator_info(matmul_info);
node->SetUserData<OperatorInfo>(matmul_info);
std::string name_expect = "MatMulInfo00";
std::string name_test = matmul_info->name();
ASSERT_EQ(name_expect, name_test);

View File

@ -525,8 +525,8 @@ TEST_F(TestStepParallel, GetTensorInLayout) {
std::vector<Shapes> shape = {inputs_shape, outputs_shape};
OperatorInfoPtr matmul_info = OperatorInstance(prim, attrs, shape);
matmul_info->Init(strategyPtr);
node->set_operator_info(matmul_info);
OperatorInfoPtr distribute_operator_pre = node->operator_info();
node->SetUserData<OperatorInfo>(matmul_info);
OperatorInfoPtr distribute_operator_pre = node->GetUserData<OperatorInfo>();
TensorLayout tensorlayout_e;
std::vector<int32_t> array = {64, 64};
TensorLayout tensorlayout = GetTensorInLayout(node1, prim, distribute_operator_pre);