forked from mindspore-Ecosystem/mindspore
!3129 Decouple ir from frontend
Merge pull request !3129 from hewei/decouple_ir_frontend
This commit is contained in:
commit
a2bf5a322e
|
@ -27,6 +27,7 @@
|
|||
#include "runtime/device/kernel_info.h"
|
||||
#include "utils/graph_utils.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "frontend/parallel/ops_info/operator_info.h"
|
||||
|
||||
namespace mindspore {
|
||||
const std::string ToShortString(const TypeId &typeId) {
|
||||
|
@ -266,7 +267,7 @@ void DumpParallelInfo(const CNodePtr &node, const std::shared_ptr<SubGraphIRInfo
|
|||
return;
|
||||
}
|
||||
|
||||
auto operator_info = node->operator_info();
|
||||
auto operator_info = node->GetUserData<parallel::OperatorInfo>();
|
||||
if (operator_info == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -437,7 +437,7 @@ static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) {
|
|||
if (graph_obj == nullptr || node == nullptr) {
|
||||
return;
|
||||
}
|
||||
auto distributed_operation_info = node->operator_info();
|
||||
auto distributed_operation_info = node->GetUserData<parallel::OperatorInfo>();
|
||||
if (distributed_operation_info != nullptr) {
|
||||
auto strategyPtr = distributed_operation_info->strategy();
|
||||
if (strategyPtr != nullptr) {
|
||||
|
|
|
@ -1,293 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "frontend/operator/ops.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
namespace mindspore {
|
||||
// namespace to support primitive operators
|
||||
namespace prim {
|
||||
// Arithmetic
|
||||
const PrimitivePtr kPrimScalarAdd = std::make_shared<Primitive>("scalar_add");
|
||||
const PrimitivePtr kPrimScalarSub = std::make_shared<Primitive>("scalar_sub");
|
||||
const PrimitivePtr kPrimScalarMul = std::make_shared<Primitive>("scalar_mul");
|
||||
const PrimitivePtr kPrimScalarDiv = std::make_shared<Primitive>("scalar_div");
|
||||
const PrimitivePtr kPrimScalarFloordiv = std::make_shared<Primitive>("scalar_floordiv");
|
||||
const PrimitivePtr kPrimScalarMod = std::make_shared<Primitive>("scalar_mod");
|
||||
const PrimitivePtr kPrimScalarPow = std::make_shared<Primitive>("scalar_pow");
|
||||
const PrimitivePtr kPrimScalarTrunc = std::make_shared<Primitive>("scalar_trunc");
|
||||
const PrimitivePtr kPrimScalarFloor = std::make_shared<Primitive>("scalar_floor");
|
||||
const PrimitivePtr kPrimScalarUadd = std::make_shared<Primitive>("scalar_uadd");
|
||||
const PrimitivePtr kPrimScalarUsub = std::make_shared<Primitive>("scalar_usub");
|
||||
const PrimitivePtr kPrimScalarExp = std::make_shared<Primitive>("scalar_exp");
|
||||
const PrimitivePtr kPrimScalarLog = std::make_shared<Primitive>("scalar_log");
|
||||
const PrimitivePtr kPrimScalarSin = std::make_shared<Primitive>("scalar_sin");
|
||||
const PrimitivePtr kPrimScalarCos = std::make_shared<Primitive>("scalar_cos");
|
||||
const PrimitivePtr kPrimScalarTan = std::make_shared<Primitive>("scalar_tan");
|
||||
|
||||
// Comparisons
|
||||
const PrimitivePtr kPrimScalarEq = std::make_shared<Primitive>("scalar_eq");
|
||||
const PrimitivePtr kPrimScalarLt = std::make_shared<Primitive>("scalar_lt");
|
||||
const PrimitivePtr kPrimScalarGt = std::make_shared<Primitive>("scalar_gt");
|
||||
const PrimitivePtr kPrimScalarNe = std::make_shared<Primitive>("scalar_ne");
|
||||
const PrimitivePtr kPrimScalarLe = std::make_shared<Primitive>("scalar_le");
|
||||
const PrimitivePtr kPrimScalarGe = std::make_shared<Primitive>("scalar_ge");
|
||||
const PrimitivePtr kPrimBoolNot = std::make_shared<Primitive>("bool_not");
|
||||
const PrimitivePtr kPrimBoolAnd = std::make_shared<Primitive>("bool_and");
|
||||
const PrimitivePtr kPrimBoolOr = std::make_shared<Primitive>("bool_or");
|
||||
const PrimitivePtr kPrimBoolEq = std::make_shared<Primitive>("bool_eq");
|
||||
const PrimitivePtr kPrimGreater = std::make_shared<Primitive>("Greater");
|
||||
const PrimitivePtr kPrimGreaterEqual = std::make_shared<Primitive>("GreaterEqual");
|
||||
const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less");
|
||||
const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual");
|
||||
const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal");
|
||||
const PrimitivePtr kPrimNotEqual = std::make_shared<Primitive>("NotEqual");
|
||||
|
||||
// Type introspection
|
||||
const PrimitivePtr kPrimTypeOf = std::make_shared<Primitive>("typeof");
|
||||
const PrimitivePtr kPrimHasType = std::make_shared<Primitive>("hastype");
|
||||
|
||||
// Statements
|
||||
const PrimitivePtr kPrimSwitch = std::make_shared<Primitive>("switch");
|
||||
const PrimitivePtr kPrimSwitchLayer = std::make_shared<Primitive>("switch_layer");
|
||||
const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return");
|
||||
const PrimitivePtr kPrimAssign = std::make_shared<Primitive>("Assign");
|
||||
const PrimitivePtr kPrimAssignAdd = std::make_shared<Primitive>("AssignAdd");
|
||||
const PrimitivePtr kPrimAssignSub = std::make_shared<Primitive>("AssignSub");
|
||||
const PrimitivePtr kPrimSelect = std::make_shared<Primitive>("Select");
|
||||
const PrimitivePtr kPrimCall = std::make_shared<Primitive>("call");
|
||||
|
||||
const PrimitivePtr kPrimDistribute = std::make_shared<Primitive>("distribute");
|
||||
const PrimitivePtr kPrimDot = std::make_shared<Primitive>("dot");
|
||||
const PrimitivePtr kPrimIm2Col = std::make_shared<Primitive>("im2col");
|
||||
const PrimitivePtr kPrimCol2Im = std::make_shared<Primitive>("col2im");
|
||||
const PrimitivePtr kPrimIm2ColV1 = std::make_shared<Primitive>("im2col_v1");
|
||||
const PrimitivePtr kPrimCol2ImV1 = std::make_shared<Primitive>("col2im_v1");
|
||||
|
||||
const PrimitivePtr kPrimResolve = std::make_shared<Primitive>("resolve");
|
||||
const PrimitivePtr kPrimEmbed = std::make_shared<Primitive>("embed");
|
||||
const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEmbed");
|
||||
const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance");
|
||||
|
||||
const PrimitivePtr kPrimLabelGoto = std::make_shared<Primitive>("LabelGoto");
|
||||
const PrimitivePtr kPrimLabelSwitch = std::make_shared<Primitive>("LabelSwitch");
|
||||
const PrimitivePtr kPrimLabelSet = std::make_shared<Primitive>("LabelSet");
|
||||
|
||||
// Structure
|
||||
const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal");
|
||||
const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat");
|
||||
const PrimitivePtr kPrimMakeTuple = std::make_shared<Primitive>("make_tuple");
|
||||
const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list");
|
||||
const PrimitivePtr kPrimMakeDict = std::make_shared<Primitive>("make_dict");
|
||||
const PrimitivePtr kPrimMakeKeywordArg = std::make_shared<Primitive>("make_keyword_arg");
|
||||
const PrimitivePtr kPrimExtractKeywordArg = std::make_shared<Primitive>("extract_keyword_arg");
|
||||
const PrimitivePtr kPrimMakeSlice = std::make_shared<Primitive>("make_slice");
|
||||
const PrimitivePtr kPrimMakeRecord = std::make_shared<Primitive>("make_record");
|
||||
const PrimitivePtr kPrimTupleGetItem = std::make_shared<Primitive>("tuple_getitem");
|
||||
const PrimitivePtr kPrimListGetItem = std::make_shared<Primitive>("list_getitem");
|
||||
const PrimitivePtr kPrimArrayGetItem = std::make_shared<Primitive>("array_getitem");
|
||||
const PrimitivePtr kPrimTupleSetItem = std::make_shared<Primitive>("tuple_setitem");
|
||||
const PrimitivePtr kPrimListSetItem = std::make_shared<Primitive>("list_setitem");
|
||||
const PrimitivePtr kPrimArraySetItem = std::make_shared<Primitive>("array_setitem");
|
||||
const PrimitivePtr kPrimDictGetItem = std::make_shared<Primitive>("dict_getitem");
|
||||
const PrimitivePtr kPrimDictSetItem = std::make_shared<Primitive>("dict_setitem");
|
||||
const PrimitivePtr kPrimListAppend = std::make_shared<Primitive>("list_append");
|
||||
const PrimitivePtr kPrimGetAttr = std::make_shared<Primitive>("getattr");
|
||||
const PrimitivePtr kPrimTupleLen = std::make_shared<Primitive>("tuple_len");
|
||||
const PrimitivePtr kPrimDictLen = std::make_shared<Primitive>("dict_len");
|
||||
const PrimitivePtr kPrimListLen = std::make_shared<Primitive>("list_len");
|
||||
const PrimitivePtr kPrimArrayLen = std::make_shared<Primitive>("array_len");
|
||||
const PrimitivePtr kPrimListMap = std::make_shared<Primitive>("list_map");
|
||||
const PrimitivePtr kPrimListReduce = std::make_shared<Primitive>("list_reduce");
|
||||
const PrimitivePtr kPrimTupleReversed = std::make_shared<Primitive>("tuple_reversed");
|
||||
|
||||
const PrimitivePtr kPrimTileShape = std::make_shared<Primitive>("tile_shape");
|
||||
const PrimitivePtr kPrimReducedShape = std::make_shared<Primitive>("reduced_shape");
|
||||
const PrimitivePtr kPrimTupleDiv = std::make_shared<Primitive>("tuple_div");
|
||||
const PrimitivePtr kPrimTupleToArray = std::make_shared<Primitive>("tuple_to_array");
|
||||
const PrimitivePtr kPrimShapeMul = std::make_shared<Primitive>("shape_mul");
|
||||
const PrimitivePtr kPrimGenerateShapeIndex = std::make_shared<Primitive>("generate_shape_index");
|
||||
const PrimitivePtr kPrimGenerateInverseIndex = std::make_shared<Primitive>("generate_inverse_index");
|
||||
const PrimitivePtr kPrimTupleEqual = std::make_shared<Primitive>("tuple_equal");
|
||||
const PrimitivePtr kPrimListEqual = std::make_shared<Primitive>("list_equal");
|
||||
const PrimitivePtr kPrimMakeRange = std::make_shared<Primitive>("make_range");
|
||||
const PrimitivePtr kPrimStopGradient = std::make_shared<Primitive>("stop_gradient");
|
||||
|
||||
// Arrays
|
||||
const PrimitivePtr kPrimScalarToArray = std::make_shared<Primitive>("scalar_to_array");
|
||||
const PrimitivePtr kPrimArrayToScalar = std::make_shared<Primitive>("array_to_scalar");
|
||||
const PrimitivePtr kPrimBroadcastShape = std::make_shared<Primitive>("broadcast_shape");
|
||||
const PrimitivePtr kPrimArrayMap = std::make_shared<Primitive>("array_map");
|
||||
const PrimitivePtr kPrimArrayReduce = std::make_shared<Primitive>("array_reduce");
|
||||
const PrimitivePtr kPrimShape = std::make_shared<Primitive>("Shape");
|
||||
const PrimitivePtr kPrimCast = std::make_shared<Primitive>("Cast");
|
||||
const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat");
|
||||
const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze");
|
||||
const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose");
|
||||
const PrimitivePtr kPrimGatherV2 = std::make_shared<Primitive>("GatherV2");
|
||||
const PrimitivePtr kPrimEmbeddingLookup = std::make_shared<Primitive>("EmbeddingLookup");
|
||||
const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared<Primitive>("EmbeddingLookupCommGrad");
|
||||
const PrimitivePtr kPrimSize = std::make_shared<Primitive>("Size");
|
||||
const PrimitivePtr kPrimArgMax = std::make_shared<Primitive>("Argmax");
|
||||
const PrimitivePtr kPrimPack = std::make_shared<Primitive>("Pack");
|
||||
const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>("UnsortedSegmentSum");
|
||||
const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin");
|
||||
const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("ConcatOffset");
|
||||
const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape");
|
||||
const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile");
|
||||
const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN");
|
||||
const PrimitivePtr KPrimTransData = std::make_shared<Primitive>("TransData");
|
||||
const PrimitivePtr kPrimNMSWithMask = std::make_shared<Primitive>("NMSWithMask");
|
||||
const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad");
|
||||
const PrimitivePtr kPrimArgMaxWithValue = std::make_shared<Primitive>("ArgMaxWithValue");
|
||||
|
||||
// Maths
|
||||
const PrimitivePtr kPrimTensorAdd = std::make_shared<Primitive>("TensorAdd");
|
||||
const PrimitivePtr kPrimMatMul = std::make_shared<Primitive>("MatMul");
|
||||
const PrimitivePtr kPrimBatchMatMul = std::make_shared<Primitive>("BatchMatMul");
|
||||
const PrimitivePtr kPrimMaximumGrad = std::make_shared<Primitive>("MaximumGrad");
|
||||
const PrimitivePtr kPrimMinimumGrad = std::make_shared<Primitive>("MinimumGrad");
|
||||
const PrimitivePtr kPrimReduceMean = std::make_shared<Primitive>("ReduceMean");
|
||||
const PrimitivePtr kPrimReduceSum = std::make_shared<Primitive>("ReduceSum");
|
||||
const PrimitivePtr kPrimReduceAll = std::make_shared<Primitive>("ReduceAll");
|
||||
const PrimitivePtr kPrimReduceMax = std::make_shared<Primitive>("ReduceMax");
|
||||
const PrimitivePtr kPrimReduceMin = std::make_shared<Primitive>("ReduceMin");
|
||||
const PrimitivePtr kPrimNeg = std::make_shared<Primitive>("Neg");
|
||||
const PrimitivePtr kPrimSub = std::make_shared<Primitive>("Sub");
|
||||
const PrimitivePtr kPrimMul = std::make_shared<Primitive>("Mul");
|
||||
const PrimitivePtr kPrimMinimum = std::make_shared<Primitive>("Minimum");
|
||||
const PrimitivePtr kPrimMaximum = std::make_shared<Primitive>("Maximum");
|
||||
const PrimitivePtr kPrimSquare = std::make_shared<Primitive>("Square");
|
||||
const PrimitivePtr kPrimCumSum = std::make_shared<Primitive>("CumSum");
|
||||
const PrimitivePtr kPrimCumProd = std::make_shared<Primitive>("CumProd");
|
||||
const PrimitivePtr kPrimSubscalar = std::make_shared<Primitive>("Subscalar");
|
||||
const PrimitivePtr kPrimInplaceAdd = std::make_shared<Primitive>("InplaceAdd");
|
||||
const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub");
|
||||
const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow");
|
||||
const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>("RealDiv");
|
||||
const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt");
|
||||
const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>("Reciprocal");
|
||||
const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandDims");
|
||||
|
||||
// NN
|
||||
const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
|
||||
const PrimitivePtr kPrimSoftmax = std::make_shared<Primitive>("Softmax");
|
||||
const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax");
|
||||
const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad");
|
||||
const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh");
|
||||
const PrimitivePtr kPrimTanhGrad = std::make_shared<Primitive>("TanhGrad");
|
||||
const PrimitivePtr kPrimPooling = std::make_shared<Primitive>("Pooling");
|
||||
const PrimitivePtr kPrimPoolingGrad = std::make_shared<Primitive>("PoolingGrad");
|
||||
const PrimitivePtr kPrimMaxPool = std::make_shared<Primitive>("MaxPool");
|
||||
const PrimitivePtr kPrimMaxPoolGrad = std::make_shared<Primitive>("MaxPoolGrad");
|
||||
const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared<Primitive>("ApplyCenteredRMSProp");
|
||||
const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad");
|
||||
const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm");
|
||||
const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D");
|
||||
const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>("FusedBatchNormGrad");
|
||||
const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm");
|
||||
const PrimitivePtr kPrimBatchNormGrad = std::make_shared<Primitive>("BatchNormGrad");
|
||||
const PrimitivePtr kPrimReluGrad = std::make_shared<Primitive>("ReluGrad");
|
||||
const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared<Primitive>("Conv2DBackpropInput");
|
||||
const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive>("Conv2DBackpropFilter");
|
||||
const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared<Primitive>("DepthwiseConv2dNative");
|
||||
const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter =
|
||||
std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropFilter");
|
||||
const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput =
|
||||
std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropInput");
|
||||
const PrimitivePtr kPrimBiasAddGrad = std::make_shared<Primitive>("BiasAddGrad");
|
||||
const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits = std::make_shared<Primitive>("SoftmaxCrossEntropyWithLogits");
|
||||
const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits =
|
||||
std::make_shared<Primitive>("SparseSoftmaxCrossEntropyWithLogits");
|
||||
const PrimitivePtr kPrimMomentum = std::make_shared<Primitive>("Momentum");
|
||||
const PrimitivePtr kPrimApplyMomentum = std::make_shared<Primitive>("ApplyMomentum");
|
||||
const PrimitivePtr kPrimLayerNorm = std::make_shared<Primitive>("LayerNorm");
|
||||
const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>("LayerNormGrad");
|
||||
const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared<Primitive>("LayerNormXBackprop");
|
||||
const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Primitive>("LayerNormBetaGammaBackprop");
|
||||
const PrimitivePtr kPrimDropoutGenMask = std::make_shared<Primitive>("DropoutGenMask");
|
||||
const PrimitivePtr kPrimDropoutDoMask = std::make_shared<Primitive>("DropoutDoMask");
|
||||
const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot");
|
||||
const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu");
|
||||
const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad");
|
||||
const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU");
|
||||
const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
|
||||
const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike");
|
||||
const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
|
||||
const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");
|
||||
const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer");
|
||||
const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel");
|
||||
const PrimitivePtr kPrimApplyRMSProp = std::make_shared<Primitive>("ApplyRMSProp");
|
||||
|
||||
// Other miscellaneous
|
||||
const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity");
|
||||
const PrimitivePtr kPrimPartial = std::make_shared<Primitive>("Partial");
|
||||
const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J");
|
||||
const PrimitivePtr kPrimEnvSetItem = std::make_shared<Primitive>("env_setitem");
|
||||
const PrimitivePtr kPrimEnvGetItem = std::make_shared<Primitive>("env_getitem");
|
||||
const PrimitivePtr kPrimEnvAdd = std::make_shared<Primitive>("env_add");
|
||||
const PrimitivePtr kPrimMakeRefKey = std::make_shared<Primitive>("MakeRefKey");
|
||||
const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key");
|
||||
const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
|
||||
const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin");
|
||||
const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
|
||||
const PrimitivePtr kPrimHookBackward = std::make_shared<Primitive>("HookBackward");
|
||||
const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType");
|
||||
const PrimitivePtr kPrimSameTypeShape = std::make_shared<Primitive>("SameTypeShape");
|
||||
const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop");
|
||||
const PrimitivePtr kPrimPrint = std::make_shared<Primitive>("Print");
|
||||
|
||||
const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");
|
||||
const PrimitivePtr kPrimDepend = std::make_shared<Primitive>("Depend");
|
||||
const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem");
|
||||
|
||||
const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("BroadcastGradientArgs");
|
||||
const PrimitivePtr kPrimControlDepend = std::make_shared<Primitive>("ControlDepend");
|
||||
const PrimitivePtr kPrimIs_ = std::make_shared<Primitive>("is_");
|
||||
const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not");
|
||||
const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict");
|
||||
const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
|
||||
const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast");
|
||||
const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant");
|
||||
const PrimitivePtr kPrimEquivFormat = std::make_shared<Primitive>("EquivFormat");
|
||||
|
||||
// Comm ops
|
||||
const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
|
||||
const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
|
||||
const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
|
||||
const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
|
||||
|
||||
// Debug ops
|
||||
const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary");
|
||||
const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary");
|
||||
const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary");
|
||||
const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary");
|
||||
const PrimitivePtr kPrimDebug = std::make_shared<Primitive>("Debug");
|
||||
|
||||
// IndexedSlices
|
||||
const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared<Primitive>("MakeIndexedSlices");
|
||||
const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared<Primitive>("IndexedSlicesGetValues");
|
||||
const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared<Primitive>("IndexedSlicesGetIndices");
|
||||
const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared<Primitive>("IndexedSlicesGetDenseShape");
|
||||
|
||||
// SparseTensor
|
||||
const PrimitivePtr kPrimMakeSparseTensor = std::make_shared<Primitive>("MakeSparseTensor");
|
||||
const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared<Primitive>("SparseTensorGetValues");
|
||||
const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primitive>("SparseTensorGetIndices");
|
||||
const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape");
|
||||
} // namespace prim
|
||||
} // namespace mindspore
|
|
@ -22,6 +22,7 @@
|
|||
#include <memory>
|
||||
#include "ir/anf.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "base/core_ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
// namespace to support primitive operators
|
||||
|
@ -31,273 +32,158 @@ ValuePtr GetPythonOps(const std::string &op_name,
|
|||
bool use_signature = false);
|
||||
|
||||
// Arithmetic
|
||||
extern const PrimitivePtr kPrimScalarAdd;
|
||||
extern const PrimitivePtr kPrimScalarSub;
|
||||
extern const PrimitivePtr kPrimScalarMul;
|
||||
extern const PrimitivePtr kPrimScalarDiv;
|
||||
extern const PrimitivePtr kPrimScalarFloordiv;
|
||||
extern const PrimitivePtr kPrimScalarMod;
|
||||
extern const PrimitivePtr kPrimScalarPow;
|
||||
extern const PrimitivePtr kPrimScalarTrunc;
|
||||
extern const PrimitivePtr kPrimScalarFloor;
|
||||
extern const PrimitivePtr kPrimScalarUadd;
|
||||
extern const PrimitivePtr kPrimScalarUsub;
|
||||
extern const PrimitivePtr kPrimScalarExp;
|
||||
extern const PrimitivePtr kPrimScalarLog;
|
||||
extern const PrimitivePtr kPrimScalarSin;
|
||||
extern const PrimitivePtr kPrimScalarCos;
|
||||
extern const PrimitivePtr kPrimScalarTan;
|
||||
inline const PrimitivePtr kPrimScalarAdd = std::make_shared<Primitive>("scalar_add");
|
||||
inline const PrimitivePtr kPrimScalarSub = std::make_shared<Primitive>("scalar_sub");
|
||||
inline const PrimitivePtr kPrimScalarMul = std::make_shared<Primitive>("scalar_mul");
|
||||
inline const PrimitivePtr kPrimScalarDiv = std::make_shared<Primitive>("scalar_div");
|
||||
inline const PrimitivePtr kPrimScalarFloordiv = std::make_shared<Primitive>("scalar_floordiv");
|
||||
inline const PrimitivePtr kPrimScalarMod = std::make_shared<Primitive>("scalar_mod");
|
||||
inline const PrimitivePtr kPrimScalarPow = std::make_shared<Primitive>("scalar_pow");
|
||||
inline const PrimitivePtr kPrimScalarTrunc = std::make_shared<Primitive>("scalar_trunc");
|
||||
inline const PrimitivePtr kPrimScalarFloor = std::make_shared<Primitive>("scalar_floor");
|
||||
inline const PrimitivePtr kPrimScalarUadd = std::make_shared<Primitive>("scalar_uadd");
|
||||
inline const PrimitivePtr kPrimScalarUsub = std::make_shared<Primitive>("scalar_usub");
|
||||
inline const PrimitivePtr kPrimScalarExp = std::make_shared<Primitive>("scalar_exp");
|
||||
inline const PrimitivePtr kPrimScalarLog = std::make_shared<Primitive>("scalar_log");
|
||||
inline const PrimitivePtr kPrimScalarSin = std::make_shared<Primitive>("scalar_sin");
|
||||
inline const PrimitivePtr kPrimScalarCos = std::make_shared<Primitive>("scalar_cos");
|
||||
inline const PrimitivePtr kPrimScalarTan = std::make_shared<Primitive>("scalar_tan");
|
||||
|
||||
// Comparisons
|
||||
extern const PrimitivePtr kPrimScalarEq;
|
||||
extern const PrimitivePtr kPrimScalarLt;
|
||||
extern const PrimitivePtr kPrimScalarGt;
|
||||
extern const PrimitivePtr kPrimScalarNe;
|
||||
extern const PrimitivePtr kPrimScalarLe;
|
||||
extern const PrimitivePtr kPrimScalarGe;
|
||||
extern const PrimitivePtr kPrimBoolNot;
|
||||
extern const PrimitivePtr kPrimBoolAnd;
|
||||
extern const PrimitivePtr kPrimBoolOr;
|
||||
extern const PrimitivePtr kPrimBoolEq;
|
||||
extern const PrimitivePtr kPrimGreater;
|
||||
extern const PrimitivePtr kPrimGreaterEqual;
|
||||
extern const PrimitivePtr kPrimLess;
|
||||
extern const PrimitivePtr kPrimLessEqual;
|
||||
extern const PrimitivePtr kPrimEqual;
|
||||
extern const PrimitivePtr kPrimNotEqual;
|
||||
inline const PrimitivePtr kPrimScalarEq = std::make_shared<Primitive>("scalar_eq");
|
||||
inline const PrimitivePtr kPrimScalarLt = std::make_shared<Primitive>("scalar_lt");
|
||||
inline const PrimitivePtr kPrimScalarGt = std::make_shared<Primitive>("scalar_gt");
|
||||
inline const PrimitivePtr kPrimScalarNe = std::make_shared<Primitive>("scalar_ne");
|
||||
inline const PrimitivePtr kPrimScalarLe = std::make_shared<Primitive>("scalar_le");
|
||||
inline const PrimitivePtr kPrimScalarGe = std::make_shared<Primitive>("scalar_ge");
|
||||
inline const PrimitivePtr kPrimBoolNot = std::make_shared<Primitive>("bool_not");
|
||||
inline const PrimitivePtr kPrimBoolAnd = std::make_shared<Primitive>("bool_and");
|
||||
inline const PrimitivePtr kPrimBoolOr = std::make_shared<Primitive>("bool_or");
|
||||
inline const PrimitivePtr kPrimBoolEq = std::make_shared<Primitive>("bool_eq");
|
||||
inline const PrimitivePtr kPrimGreater = std::make_shared<Primitive>("Greater");
|
||||
inline const PrimitivePtr kPrimGreaterEqual = std::make_shared<Primitive>("GreaterEqual");
|
||||
inline const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less");
|
||||
inline const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual");
|
||||
inline const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal");
|
||||
inline const PrimitivePtr kPrimNotEqual = std::make_shared<Primitive>("NotEqual");
|
||||
|
||||
// Type introspection
|
||||
extern const PrimitivePtr kPrimTypeOf;
|
||||
extern const PrimitivePtr kPrimHasType;
|
||||
inline const PrimitivePtr kPrimTypeOf = std::make_shared<Primitive>("typeof");
|
||||
inline const PrimitivePtr kPrimHasType = std::make_shared<Primitive>("hastype");
|
||||
|
||||
// Statements
|
||||
extern const PrimitivePtr kPrimSwitch;
|
||||
extern const PrimitivePtr kPrimSwitchLayer;
|
||||
extern const PrimitivePtr kPrimReturn;
|
||||
extern const PrimitivePtr kPrimAssign;
|
||||
extern const PrimitivePtr kPrimAssignAdd;
|
||||
extern const PrimitivePtr kPrimAssignSub;
|
||||
extern const PrimitivePtr kPrimSelect;
|
||||
extern const PrimitivePtr kPrimCall;
|
||||
inline const PrimitivePtr kPrimDistribute = std::make_shared<Primitive>("distribute");
|
||||
inline const PrimitivePtr kPrimDot = std::make_shared<Primitive>("dot");
|
||||
inline const PrimitivePtr kPrimIm2Col = std::make_shared<Primitive>("im2col");
|
||||
inline const PrimitivePtr kPrimCol2Im = std::make_shared<Primitive>("col2im");
|
||||
inline const PrimitivePtr kPrimIm2ColV1 = std::make_shared<Primitive>("im2col_v1");
|
||||
inline const PrimitivePtr kPrimCol2ImV1 = std::make_shared<Primitive>("col2im_v1");
|
||||
|
||||
extern const PrimitivePtr kPrimDistribute;
|
||||
extern const PrimitivePtr kPrimDot;
|
||||
extern const PrimitivePtr kPrimIm2Col;
|
||||
extern const PrimitivePtr kPrimCol2Im;
|
||||
extern const PrimitivePtr kPrimIm2ColV1;
|
||||
extern const PrimitivePtr kPrimCol2ImV1;
|
||||
inline const PrimitivePtr kPrimResolve = std::make_shared<Primitive>("resolve");
|
||||
inline const PrimitivePtr kPrimEmbed = std::make_shared<Primitive>("embed");
|
||||
inline const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEmbed");
|
||||
inline const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance");
|
||||
|
||||
extern const PrimitivePtr kPrimResolve;
|
||||
extern const PrimitivePtr kPrimEmbed;
|
||||
extern const PrimitivePtr kPrimRefToEmbed;
|
||||
extern const PrimitivePtr kPrimCreateInstance;
|
||||
|
||||
extern const PrimitivePtr kPrimLabelGoto;
|
||||
extern const PrimitivePtr kPrimLabelSwitch;
|
||||
extern const PrimitivePtr kPrimLabelSet;
|
||||
|
||||
// Structure
|
||||
extern const PrimitivePtr kPrimStringEqual;
|
||||
extern const PrimitivePtr kPrimStringConcat;
|
||||
extern const PrimitivePtr kPrimMakeTuple;
|
||||
extern const PrimitivePtr kPrimMakeList;
|
||||
extern const PrimitivePtr kPrimMakeDict;
|
||||
extern const PrimitivePtr kPrimMakeKeywordArg;
|
||||
extern const PrimitivePtr kPrimExtractKeywordArg;
|
||||
extern const PrimitivePtr kPrimMakeSlice;
|
||||
extern const PrimitivePtr kPrimMakeRecord;
|
||||
extern const PrimitivePtr kPrimTupleGetItem;
|
||||
extern const PrimitivePtr kPrimListGetItem;
|
||||
extern const PrimitivePtr kPrimArrayGetItem;
|
||||
extern const PrimitivePtr kPrimTupleSetItem;
|
||||
extern const PrimitivePtr kPrimListSetItem;
|
||||
extern const PrimitivePtr kPrimArraySetItem;
|
||||
extern const PrimitivePtr kPrimDictGetItem;
|
||||
extern const PrimitivePtr kPrimDictSetItem;
|
||||
extern const PrimitivePtr kPrimListAppend;
|
||||
extern const PrimitivePtr kPrimGetAttr;
|
||||
extern const PrimitivePtr kPrimTupleLen;
|
||||
extern const PrimitivePtr kPrimDictLen;
|
||||
extern const PrimitivePtr kPrimListLen;
|
||||
extern const PrimitivePtr kPrimArrayLen;
|
||||
extern const PrimitivePtr kPrimListMap;
|
||||
extern const PrimitivePtr kPrimListReduce;
|
||||
extern const PrimitivePtr kPrimTupleReversed;
|
||||
extern const PrimitivePtr kPrimTileShape;
|
||||
extern const PrimitivePtr kPrimReducedShape;
|
||||
extern const PrimitivePtr kPrimTupleDiv;
|
||||
extern const PrimitivePtr kPrimTupleToArray;
|
||||
extern const PrimitivePtr kPrimShapeMul;
|
||||
extern const PrimitivePtr kPrimGenerateShapeIndex;
|
||||
extern const PrimitivePtr kPrimGenerateInverseIndex;
|
||||
extern const PrimitivePtr kPrimTupleEqual;
|
||||
extern const PrimitivePtr kPrimListEqual;
|
||||
extern const PrimitivePtr kPrimMakeRange;
|
||||
extern const PrimitivePtr kPrimStopGradient;
|
||||
inline const PrimitivePtr kPrimLabelGoto = std::make_shared<Primitive>("LabelGoto");
|
||||
inline const PrimitivePtr kPrimLabelSwitch = std::make_shared<Primitive>("LabelSwitch");
|
||||
inline const PrimitivePtr kPrimLabelSet = std::make_shared<Primitive>("LabelSet");
|
||||
|
||||
// Arrays
|
||||
extern const PrimitivePtr kPrimScalarToArray;
|
||||
extern const PrimitivePtr kPrimArrayToScalar;
|
||||
extern const PrimitivePtr kPrimBroadcastShape;
|
||||
extern const PrimitivePtr kPrimArrayMap;
|
||||
extern const PrimitivePtr kPrimArrayReduce;
|
||||
extern const PrimitivePtr kPrimShape;
|
||||
extern const PrimitivePtr kPrimCast;
|
||||
extern const PrimitivePtr kPrimConcat;
|
||||
extern const PrimitivePtr kPrimSqueeze;
|
||||
extern const PrimitivePtr kPrimTranspose;
|
||||
extern const PrimitivePtr kPrimGatherV2;
|
||||
extern const PrimitivePtr kPrimEmbeddingLookup;
|
||||
extern const PrimitivePtr kPrimEmbeddingLookupCommGrad;
|
||||
extern const PrimitivePtr kPrimSize;
|
||||
extern const PrimitivePtr kPrimArgMax;
|
||||
extern const PrimitivePtr kPrimPack;
|
||||
extern const PrimitivePtr kPrimUnpack;
|
||||
extern const PrimitivePtr kPrimUnsortedSegmentMin;
|
||||
extern const PrimitivePtr kPrimUnsortedSegmentSum;
|
||||
extern const PrimitivePtr kPrimConcatOffset;
|
||||
extern const PrimitivePtr kPrimReshape;
|
||||
extern const PrimitivePtr kPrimTile;
|
||||
extern const PrimitivePtr kPrimAddN;
|
||||
extern const PrimitivePtr KPrimTransData;
|
||||
extern const PrimitivePtr kPrimNMSWithMask;
|
||||
extern const PrimitivePtr kPrimPad;
|
||||
extern const PrimitivePtr kPrimArgMaxWithValue;
|
||||
extern const PrimitivePtr kPrimRealDiv;
|
||||
extern const PrimitivePtr kPrimSqrt;
|
||||
extern const PrimitivePtr kPrimReciprocal;
|
||||
extern const PrimitivePtr kPrimExpandDims;
|
||||
|
||||
// Maths
|
||||
extern const PrimitivePtr kPrimTensorAdd;
|
||||
extern const PrimitivePtr kPrimMatMul;
|
||||
extern const PrimitivePtr kPrimBatchMatMul;
|
||||
extern const PrimitivePtr kPrimMaximumGrad;
|
||||
extern const PrimitivePtr kPrimMinimumGrad;
|
||||
extern const PrimitivePtr kPrimReduceMean;
|
||||
extern const PrimitivePtr kPrimReduceSum;
|
||||
extern const PrimitivePtr kPrimReduceAll;
|
||||
extern const PrimitivePtr kPrimReduceMax;
|
||||
extern const PrimitivePtr kPrimReduceMin;
|
||||
extern const PrimitivePtr kPrimNeg;
|
||||
extern const PrimitivePtr kPrimSub;
|
||||
extern const PrimitivePtr kPrimMul;
|
||||
extern const PrimitivePtr kPrimRealDiv;
|
||||
extern const PrimitivePtr kPrimMinimum;
|
||||
extern const PrimitivePtr kPrimMaximum;
|
||||
extern const PrimitivePtr kPrimSquare;
|
||||
extern const PrimitivePtr kPrimSqrt;
|
||||
extern const PrimitivePtr kPrimEqual;
|
||||
extern const PrimitivePtr kPrimLess;
|
||||
extern const PrimitivePtr kPrimLessEqual;
|
||||
extern const PrimitivePtr kPrimCumSum;
|
||||
extern const PrimitivePtr kPrimCumProd;
|
||||
extern const PrimitivePtr kPrimSubscalar;
|
||||
extern const PrimitivePtr kPrimInplaceAdd;
|
||||
extern const PrimitivePtr kPrimInplaceSub;
|
||||
extern const PrimitivePtr kPrimPow;
|
||||
inline const PrimitivePtr kPrimScalarToArray = std::make_shared<Primitive>("scalar_to_array");
|
||||
inline const PrimitivePtr kPrimArrayToScalar = std::make_shared<Primitive>("array_to_scalar");
|
||||
inline const PrimitivePtr kPrimBroadcastShape = std::make_shared<Primitive>("broadcast_shape");
|
||||
inline const PrimitivePtr kPrimArrayMap = std::make_shared<Primitive>("array_map");
|
||||
inline const PrimitivePtr kPrimArrayReduce = std::make_shared<Primitive>("array_reduce");
|
||||
inline const PrimitivePtr kPrimShape = std::make_shared<Primitive>("Shape");
|
||||
inline const PrimitivePtr kPrimCast = std::make_shared<Primitive>("Cast");
|
||||
inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat");
|
||||
inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze");
|
||||
inline const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose");
|
||||
inline const PrimitivePtr kPrimGatherV2 = std::make_shared<Primitive>("GatherV2");
|
||||
inline const PrimitivePtr kPrimEmbeddingLookup = std::make_shared<Primitive>("EmbeddingLookup");
|
||||
inline const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared<Primitive>("EmbeddingLookupCommGrad");
|
||||
inline const PrimitivePtr kPrimSize = std::make_shared<Primitive>("Size");
|
||||
inline const PrimitivePtr kPrimArgMax = std::make_shared<Primitive>("Argmax");
|
||||
inline const PrimitivePtr kPrimPack = std::make_shared<Primitive>("Pack");
|
||||
inline const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>("UnsortedSegmentSum");
|
||||
inline const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin");
|
||||
inline const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("ConcatOffset");
|
||||
inline const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape");
|
||||
inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile");
|
||||
inline const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN");
|
||||
inline const PrimitivePtr KPrimTransData = std::make_shared<Primitive>("TransData");
|
||||
inline const PrimitivePtr kPrimNMSWithMask = std::make_shared<Primitive>("NMSWithMask");
|
||||
inline const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad");
|
||||
inline const PrimitivePtr kPrimArgMaxWithValue = std::make_shared<Primitive>("ArgMaxWithValue");
|
||||
|
||||
// NN
|
||||
extern const PrimitivePtr kPrimFlatten;
|
||||
extern const PrimitivePtr kPrimSoftmax;
|
||||
extern const PrimitivePtr kPrimLogSoftmax;
|
||||
extern const PrimitivePtr kPrimLogSoftmaxGrad;
|
||||
extern const PrimitivePtr kPrimApplyCenteredRMSProp;
|
||||
extern const PrimitivePtr kPrimTanh;
|
||||
extern const PrimitivePtr kPrimTanhGrad;
|
||||
extern const PrimitivePtr kPrimPooling;
|
||||
extern const PrimitivePtr kPrimPoolingGrad;
|
||||
extern const PrimitivePtr kPrimFusedBatchNorm;
|
||||
extern const PrimitivePtr kPrimBatchNorm;
|
||||
extern const PrimitivePtr kPrimBatchNormGrad;
|
||||
extern const PrimitivePtr kPrimConv2D;
|
||||
extern const PrimitivePtr kPrimMaxPool;
|
||||
extern const PrimitivePtr kPrimMaxPoolGrad;
|
||||
extern const PrimitivePtr kPrimAvgPoolGrad;
|
||||
extern const PrimitivePtr kPrimFusedBatchNormGrad;
|
||||
extern const PrimitivePtr kPrimReluGrad;
|
||||
extern const PrimitivePtr kPrimConv2DBackpropInput;
|
||||
extern const PrimitivePtr kPrimConv2DBackpropFilter;
|
||||
extern const PrimitivePtr kPrimDepthwiseConv2dNative;
|
||||
extern const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter;
|
||||
extern const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput;
|
||||
|
||||
extern const PrimitivePtr kPrimBiasAddGrad;
|
||||
extern const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits;
|
||||
extern const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits;
|
||||
extern const PrimitivePtr kPrimMomentum;
|
||||
extern const PrimitivePtr kPrimApplyMomentum;
|
||||
extern const PrimitivePtr kPrimLayerNorm;
|
||||
extern const PrimitivePtr kPrimLayerNormGrad;
|
||||
extern const PrimitivePtr kPrimLayerNormXBackprop;
|
||||
extern const PrimitivePtr kPrimLayerNormBetaGammaBackprop;
|
||||
extern const PrimitivePtr kPrimDropoutGenMask;
|
||||
extern const PrimitivePtr kPrimDropoutDoMask;
|
||||
extern const PrimitivePtr kPrimOneHot;
|
||||
extern const PrimitivePtr kPrimGelu;
|
||||
extern const PrimitivePtr kPrimGeluGrad;
|
||||
extern const PrimitivePtr kPrimRelu;
|
||||
extern const PrimitivePtr kPrimReluV2;
|
||||
extern const PrimitivePtr kPrimActivation;
|
||||
extern const PrimitivePtr kPrimZerosLike;
|
||||
extern const PrimitivePtr kPrimFakeBprop;
|
||||
extern const PrimitivePtr kPrimBpropCut;
|
||||
extern const PrimitivePtr kPrimFakeQuantPerLayer;
|
||||
extern const PrimitivePtr kPrimFakeQuantPerChannel;
|
||||
extern const PrimitivePtr kPrimApplyRMSProp;
|
||||
|
||||
// Other Miscellaneous
|
||||
extern const PrimitivePtr kPrimIdentity;
|
||||
extern const PrimitivePtr kPrimPartial;
|
||||
extern const PrimitivePtr kPrimJ;
|
||||
extern const PrimitivePtr kPrimEnvSetItem;
|
||||
extern const PrimitivePtr kPrimEnvGetItem;
|
||||
extern const PrimitivePtr kPrimEnvAdd;
|
||||
extern const PrimitivePtr kPrimMakeRefKey;
|
||||
extern const PrimitivePtr kPrimMakeRef;
|
||||
extern const PrimitivePtr kPrimGetRefKey;
|
||||
extern const PrimitivePtr kPrimGetRefValue;
|
||||
extern const PrimitivePtr kPrimGetRefOrigin;
|
||||
extern const PrimitivePtr kPrimInsertGradientOf;
|
||||
extern const PrimitivePtr kPrimHookBackward;
|
||||
extern const PrimitivePtr kPrimPrintShapeType;
|
||||
extern const PrimitivePtr kPrimPrint;
|
||||
extern const PrimitivePtr kPrimSameTypeShape;
|
||||
extern const PrimitivePtr kPrimCheckBprop;
|
||||
extern const PrimitivePtr kPrimDepend;
|
||||
extern const PrimitivePtr kPrimStateSetItem;
|
||||
extern const PrimitivePtr kPrimScalarSummary;
|
||||
extern const PrimitivePtr kPrimImageSummary;
|
||||
extern const PrimitivePtr kPrimTensorSummary;
|
||||
extern const PrimitivePtr kPrimHistogramSummary;
|
||||
extern const PrimitivePtr kPrimBroadcastGradientArgs;
|
||||
extern const PrimitivePtr kPrimControlDepend;
|
||||
extern const PrimitivePtr kPrimIs_;
|
||||
extern const PrimitivePtr kPrimIsNot;
|
||||
extern const PrimitivePtr kPrimInDict;
|
||||
extern const PrimitivePtr kPrimNotInDict;
|
||||
extern const PrimitivePtr kPrimMixedPrecisionCast;
|
||||
extern const PrimitivePtr kPrimIsConsant;
|
||||
extern const PrimitivePtr kPrimEquivFormat;
|
||||
extern const PrimitivePtr kPrimDebug;
|
||||
inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
|
||||
inline const PrimitivePtr kPrimSoftmax = std::make_shared<Primitive>("Softmax");
|
||||
inline const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax");
|
||||
inline const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad");
|
||||
inline const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh");
|
||||
inline const PrimitivePtr kPrimTanhGrad = std::make_shared<Primitive>("TanhGrad");
|
||||
inline const PrimitivePtr kPrimPooling = std::make_shared<Primitive>("Pooling");
|
||||
inline const PrimitivePtr kPrimPoolingGrad = std::make_shared<Primitive>("PoolingGrad");
|
||||
inline const PrimitivePtr kPrimMaxPool = std::make_shared<Primitive>("MaxPool");
|
||||
inline const PrimitivePtr kPrimMaxPoolGrad = std::make_shared<Primitive>("MaxPoolGrad");
|
||||
inline const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared<Primitive>("ApplyCenteredRMSProp");
|
||||
inline const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad");
|
||||
inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm");
|
||||
inline const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D");
|
||||
inline const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>("FusedBatchNormGrad");
|
||||
inline const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm");
|
||||
inline const PrimitivePtr kPrimBatchNormGrad = std::make_shared<Primitive>("BatchNormGrad");
|
||||
inline const PrimitivePtr kPrimReluGrad = std::make_shared<Primitive>("ReluGrad");
|
||||
inline const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared<Primitive>("Conv2DBackpropInput");
|
||||
inline const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive>("Conv2DBackpropFilter");
|
||||
inline const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared<Primitive>("DepthwiseConv2dNative");
|
||||
inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter =
|
||||
std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropFilter");
|
||||
inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput =
|
||||
std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropInput");
|
||||
inline const PrimitivePtr kPrimBiasAddGrad = std::make_shared<Primitive>("BiasAddGrad");
|
||||
inline const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits =
|
||||
std::make_shared<Primitive>("SoftmaxCrossEntropyWithLogits");
|
||||
inline const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits =
|
||||
std::make_shared<Primitive>("SparseSoftmaxCrossEntropyWithLogits");
|
||||
inline const PrimitivePtr kPrimMomentum = std::make_shared<Primitive>("Momentum");
|
||||
inline const PrimitivePtr kPrimApplyMomentum = std::make_shared<Primitive>("ApplyMomentum");
|
||||
inline const PrimitivePtr kPrimLayerNorm = std::make_shared<Primitive>("LayerNorm");
|
||||
inline const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>("LayerNormGrad");
|
||||
inline const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared<Primitive>("LayerNormXBackprop");
|
||||
inline const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Primitive>("LayerNormBetaGammaBackprop");
|
||||
inline const PrimitivePtr kPrimDropoutGenMask = std::make_shared<Primitive>("DropoutGenMask");
|
||||
inline const PrimitivePtr kPrimDropoutDoMask = std::make_shared<Primitive>("DropoutDoMask");
|
||||
inline const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot");
|
||||
inline const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu");
|
||||
inline const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad");
|
||||
inline const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU");
|
||||
inline const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
|
||||
inline const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike");
|
||||
inline const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
|
||||
inline const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");
|
||||
inline const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer");
|
||||
inline const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel");
|
||||
inline const PrimitivePtr kPrimApplyRMSProp = std::make_shared<Primitive>("ApplyRMSProp");
|
||||
|
||||
// Comm ops
|
||||
extern const PrimitivePtr kPrimAllReduce;
|
||||
extern const PrimitivePtr kPrimMirror;
|
||||
extern const PrimitivePtr kPrimVirtualDiv;
|
||||
extern const PrimitivePtr kPrimVirtualDataset;
|
||||
inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
|
||||
inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
|
||||
inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
|
||||
inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
|
||||
|
||||
// IndexedSlices
|
||||
extern const PrimitivePtr kPrimMakeIndexedSlices;
|
||||
extern const PrimitivePtr kPrimIndexedSlicesGetValues;
|
||||
extern const PrimitivePtr kPrimIndexedSlicesGetIndices;
|
||||
extern const PrimitivePtr kPrimIndexedSlicesGetDenseShape;
|
||||
inline const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared<Primitive>("MakeIndexedSlices");
|
||||
inline const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared<Primitive>("IndexedSlicesGetValues");
|
||||
inline const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared<Primitive>("IndexedSlicesGetIndices");
|
||||
inline const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared<Primitive>("IndexedSlicesGetDenseShape");
|
||||
inline const PrimitivePtr kPrimIsIndexedSlices = std::make_shared<Primitive>("IsIndexedSlices");
|
||||
|
||||
// SparseTensor
|
||||
extern const PrimitivePtr kPrimMakeSparseTensor;
|
||||
extern const PrimitivePtr kPrimSparseTensorGetValues;
|
||||
extern const PrimitivePtr kPrimSparseTensorGetIndices;
|
||||
extern const PrimitivePtr kPrimSparseTensorGetDenseShape;
|
||||
inline const PrimitivePtr kPrimMakeSparseTensor = std::make_shared<Primitive>("MakeSparseTensor");
|
||||
inline const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared<Primitive>("SparseTensorGetValues");
|
||||
inline const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primitive>("SparseTensorGetIndices");
|
||||
inline const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape");
|
||||
|
||||
// attribute 'unroll_flag' of primitive 'switch', when 'unroll_flag' is '0', 'switch' will not unroll
|
||||
const char SWITCH_UNROLL_FLAG[] = "unroll_flag";
|
||||
|
@ -305,22 +191,6 @@ const char SWITCH_UNROLL_FLAG[] = "unroll_flag";
|
|||
// will be sunk(i.e. not unrolled)
|
||||
const int MAX_FOR_LOOP_COUNT = 600;
|
||||
|
||||
class DoSignaturePrimitive : public Primitive {
|
||||
public:
|
||||
explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function)
|
||||
: Primitive("S-Prim-" + name), function_(function) {}
|
||||
|
||||
~DoSignaturePrimitive() override = default;
|
||||
|
||||
MS_DECLARE_PARENT(DoSignaturePrimitive, Primitive)
|
||||
|
||||
const ValuePtr function() const { return function_; }
|
||||
|
||||
private:
|
||||
ValuePtr function_;
|
||||
};
|
||||
using DoSignaturePrimitivePtr = std::shared_ptr<DoSignaturePrimitive>;
|
||||
|
||||
class UnpackGraphPrimitive : public Primitive {
|
||||
public:
|
||||
explicit UnpackGraphPrimitive(const std::string &name, const bool &with_sens, const bool &need_unpack_args)
|
||||
|
|
|
@ -50,7 +50,7 @@ std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr ¶, uint32_t
|
|||
if (node_prim->name() == DEPEND && node_pair.second != 1) {
|
||||
continue;
|
||||
}
|
||||
if (IsParallelCareNode(cnode) && cnode->operator_info() != nullptr) {
|
||||
if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) {
|
||||
(void)cnode_set.emplace(cnode);
|
||||
} else {
|
||||
auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1);
|
||||
|
@ -98,11 +98,12 @@ CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint32_t recursi
|
|||
return cnode_dist;
|
||||
}
|
||||
|
||||
auto operator_info = cnode->GetUserData<OperatorInfo>();
|
||||
MS_LOG(DEBUG) << "cnode " << cnode->ToString() << " IsParallelCareNode: " << IsParallelCareNode(cnode)
|
||||
<< " operator_info: " << (cnode->operator_info() != nullptr);
|
||||
<< " operator_info: " << (operator_info != nullptr);
|
||||
|
||||
if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) {
|
||||
auto cost = cnode->operator_info()->GetForwardMemoryCostFromCNode();
|
||||
if (IsParallelCareNode(cnode) && (operator_info != nullptr)) {
|
||||
auto cost = operator_info->GetForwardMemoryCostFromCNode();
|
||||
MS_LOG(DEBUG) << "cnode " << cnode->DebugString() << " cost: " << cost;
|
||||
|
||||
if (allreduce_graph_.NodeInGraph(cnode)) {
|
||||
|
|
|
@ -83,7 +83,7 @@ Status AllreduceNode::AddPara(const AnfNodePtr &node_ptr) {
|
|||
}
|
||||
auto para_ptr = node_ptr->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(para_ptr);
|
||||
auto layout_ptr = para_ptr->tensor_layout();
|
||||
auto layout_ptr = para_ptr->GetUserData<TensorLayout>();
|
||||
if (layout_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "layout_ptr is nullptr!";
|
||||
return FAILED;
|
||||
|
|
|
@ -37,7 +37,7 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
|
|||
|
||||
for (auto para : graph_params) {
|
||||
std::string name = std::static_pointer_cast<Parameter>(para)->name();
|
||||
std::shared_ptr<parallel::TensorLayout> tensor_layout = std::static_pointer_cast<Parameter>(para)->tensor_layout();
|
||||
auto tensor_layout = para->GetUserData<parallel::TensorLayout>();
|
||||
if (tensor_layout == nullptr) {
|
||||
MS_LOG(INFO) << "GetParameterLayout nullptr name = " << name;
|
||||
} else {
|
||||
|
@ -70,7 +70,7 @@ py::dict GetCNodeStrategy(const FuncGraphPtr &graph) {
|
|||
if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto distributed_operation_info = cnode->operator_info();
|
||||
auto distributed_operation_info = cnode->GetUserData<OperatorInfo>();
|
||||
if (distributed_operation_info != nullptr) {
|
||||
auto strategyPtr = distributed_operation_info->strategy();
|
||||
if (strategyPtr != nullptr) {
|
||||
|
|
|
@ -163,6 +163,9 @@ class OperatorInfo {
|
|||
const std::string &type() const { return type_; }
|
||||
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
|
||||
|
||||
// Key for user data.
|
||||
constexpr static char key[] = "OpInfo";
|
||||
|
||||
protected:
|
||||
// needed by rec_parser
|
||||
std::string type_;
|
||||
|
|
|
@ -435,7 +435,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
|
|||
std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
|
||||
|
||||
entire_costgraph->AddOperator(operator_info);
|
||||
(void)cnode->set_operator_info(operator_info);
|
||||
cnode->SetUserData<OperatorInfo>(operator_info);
|
||||
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
|
||||
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
|
||||
<< " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
|
||||
|
@ -501,7 +501,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
|
|||
std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
|
||||
|
||||
entire_costgraph->AddOperator(operator_info);
|
||||
(void)cnode->set_operator_info(operator_info);
|
||||
cnode->SetUserData<OperatorInfo>(operator_info);
|
||||
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
|
||||
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
|
||||
<< " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
|
||||
|
@ -520,7 +520,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
|
|||
MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
|
||||
<< " does not match the Prim: " << prim->name();
|
||||
}
|
||||
(void)cnode->set_operator_info(current_op_ptr);
|
||||
cnode->SetUserData<OperatorInfo>(current_op_ptr);
|
||||
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
|
||||
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
|
||||
<< " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
|
||||
|
@ -549,6 +549,8 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
|
||||
size_t edge_count = 0;
|
||||
|
||||
auto node_op_info = cnode->GetUserData<OperatorInfo>();
|
||||
|
||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
auto prev_cnode = inputs[i]->cast<CNodePtr>();
|
||||
bool bool_result_prev_cnode = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
|
||||
|
@ -563,8 +565,8 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
(IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND);
|
||||
while (bool_result) {
|
||||
if (IsAutoParallelCareNode(prev_cnode)) {
|
||||
std::string edge_name =
|
||||
prev_cnode->operator_info()->name() + OPERATOR_TO_OPERATOR_CONNECTOR + cnode->operator_info()->name();
|
||||
auto prev_op_info = prev_cnode->GetUserData<OperatorInfo>();
|
||||
std::string edge_name = prev_op_info->name() + OPERATOR_TO_OPERATOR_CONNECTOR + node_op_info->name();
|
||||
// If the edge between these two operators already has been added, then the edge will not be added again.
|
||||
if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, i - 1)) {
|
||||
break;
|
||||
|
@ -577,22 +579,20 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
if (follow_strategy) {
|
||||
// Redistribution in not allowed on the edge.
|
||||
// Elementwise operators have the same strategy as their previous operators.
|
||||
edge_ptr = std::make_shared<Edge>(edge_name, prev_cnode->operator_info(), cnode->operator_info(),
|
||||
output_index, i - 1, false, true);
|
||||
edge_ptr = std::make_shared<Edge>(edge_name, prev_op_info, node_op_info, output_index, i - 1, false, true);
|
||||
} else {
|
||||
edge_ptr = std::make_shared<Edge>(edge_name, prev_cnode->operator_info(), cnode->operator_info(),
|
||||
output_index, i - 1, false);
|
||||
edge_ptr = std::make_shared<Edge>(edge_name, prev_op_info, node_op_info, output_index, i - 1, false);
|
||||
}
|
||||
|
||||
// Init costs for this edge
|
||||
if (edge_ptr->InitEdgeCost() != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Edge cost initialization failed";
|
||||
}
|
||||
cnode->operator_info()->AddPrevEdge(edge_ptr);
|
||||
prev_cnode->operator_info()->AddSuccEdge(edge_ptr);
|
||||
entire_costgraph->AddEdge(prev_cnode->operator_info(), cnode->operator_info(), edge_ptr);
|
||||
MS_LOG(INFO) << "Successfully adding the edge between " << prev_cnode->operator_info()->name() << " and "
|
||||
<< cnode->operator_info()->name();
|
||||
node_op_info->AddPrevEdge(edge_ptr);
|
||||
prev_op_info->AddSuccEdge(edge_ptr);
|
||||
entire_costgraph->AddEdge(prev_op_info, node_op_info, edge_ptr);
|
||||
MS_LOG(INFO) << "Successfully adding the edge between " << prev_op_info->name() << " and "
|
||||
<< node_op_info->name();
|
||||
edge_count++;
|
||||
|
||||
break;
|
||||
|
@ -633,7 +633,7 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
(IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND);
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << cnode->operator_info()->name();
|
||||
MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << node_op_info->name();
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Constructing edges for cost graph ends.";
|
||||
|
@ -750,7 +750,8 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
for (auto &target : target_set) {
|
||||
auto target_cnode = target.first->cast<CNodePtr>();
|
||||
auto input_index = target.second;
|
||||
(void)target_without_duplicate.insert(std::to_string(input_index) + target_cnode->operator_info()->name());
|
||||
(void)target_without_duplicate.insert(std::to_string(input_index) +
|
||||
target_cnode->GetUserData<OperatorInfo>()->name());
|
||||
}
|
||||
if (target_without_duplicate.size() <= 1) {
|
||||
continue;
|
||||
|
@ -830,24 +831,24 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
auto target_cnode = target.first->cast<CNodePtr>();
|
||||
auto prim = GetValueNode<PrimitivePtr>(target_cnode->input(0));
|
||||
auto input_index = target.second;
|
||||
auto target_op_info = target_cnode->GetUserData<OperatorInfo>();
|
||||
|
||||
std::string edge_name =
|
||||
std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_cnode->operator_info()->name();
|
||||
std::string edge_name = std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_op_info->name();
|
||||
// If the edge between these two operators already has been added, then the edge will not be added again.
|
||||
if (entire_costgraph->IsEdgeInCostGraph(edge_name, 0, IntToSize(input_index - 1))) {
|
||||
continue;
|
||||
}
|
||||
std::shared_ptr<Edge> edge_ptr = std::make_shared<Edge>(
|
||||
edge_name, tmp_identity_ptr, target_cnode->operator_info(), 0, input_index - 1, false, true);
|
||||
std::shared_ptr<Edge> edge_ptr =
|
||||
std::make_shared<Edge>(edge_name, tmp_identity_ptr, target_op_info, 0, input_index - 1, false, true);
|
||||
|
||||
if (edge_ptr->InitEdgeCost() != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Edge cost initialization failed";
|
||||
}
|
||||
target_cnode->operator_info()->AddPrevEdge(edge_ptr);
|
||||
target_op_info->AddPrevEdge(edge_ptr);
|
||||
tmp_identity_ptr->AddSuccEdge(edge_ptr);
|
||||
entire_costgraph->AddEdge(tmp_identity_ptr, target_cnode->operator_info(), edge_ptr);
|
||||
entire_costgraph->AddEdge(tmp_identity_ptr, target_op_info, edge_ptr);
|
||||
MS_LOG(INFO) << "Successfully adding the edge between " << tmp_identity_ptr->name() << " and "
|
||||
<< target_cnode->operator_info()->name();
|
||||
<< target_op_info->name();
|
||||
add_identity_edge = true;
|
||||
}
|
||||
if (new_identity && add_identity_edge) {
|
||||
|
@ -861,20 +862,13 @@ bool FindReshape(const CNodePtr &cnode) {
|
|||
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
|
||||
return false;
|
||||
}
|
||||
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||
if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) {
|
||||
if (!IsParallelCareNode(cnode) || !cnode->HasUserData<OperatorInfo>()) {
|
||||
return false;
|
||||
}
|
||||
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
OperatorInfoPtr operator_info = cnode->operator_info();
|
||||
if (operator_info == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr";
|
||||
}
|
||||
if (prim->name() != RESHAPE) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
return (prim->name() == RESHAPE);
|
||||
}
|
||||
|
||||
// find previous node, then obtain its strategy_cost_ vector to get its layout vector.
|
||||
|
@ -890,8 +884,9 @@ bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_
|
|||
if (!IsValueNode<Primitive>(cnode->input(0))) {
|
||||
return false;
|
||||
}
|
||||
if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) {
|
||||
*pre_operator_info = cnode->operator_info();
|
||||
auto node_op_info = cnode->GetUserData<OperatorInfo>();
|
||||
if (IsParallelCareNode(cnode) && (node_op_info != nullptr)) {
|
||||
*pre_operator_info = node_op_info;
|
||||
*out_index = 0;
|
||||
return true;
|
||||
}
|
||||
|
@ -905,8 +900,9 @@ bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_
|
|||
MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode";
|
||||
}
|
||||
CNodePtr pre_cnode = pre_node->cast<CNodePtr>();
|
||||
if (IsParallelCareNode(pre_cnode) && (pre_cnode->operator_info() != nullptr)) {
|
||||
*pre_operator_info = pre_cnode->operator_info();
|
||||
auto pre_op_info = pre_cnode->GetUserData<OperatorInfo>();
|
||||
if (IsParallelCareNode(pre_cnode) && (pre_op_info != nullptr)) {
|
||||
*pre_operator_info = pre_op_info;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
@ -945,14 +941,15 @@ bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator
|
|||
if (node_prim->name() == DEPEND && node_pair.second != 1) {
|
||||
continue;
|
||||
}
|
||||
if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) {
|
||||
auto op_info = use_apply->GetUserData<OperatorInfo>();
|
||||
if (IsParallelCareNode(use_apply) && (op_info != nullptr)) {
|
||||
MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name();
|
||||
*next_operator_info = use_apply->operator_info();
|
||||
*next_operator_info = op_info;
|
||||
*in_index = node_pair.second - 1;
|
||||
return true;
|
||||
}
|
||||
MS_LOG(DEBUG) << "FindNextNodeStraCosts failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
|
||||
<< " " << (use_apply->operator_info() != nullptr);
|
||||
<< " " << (op_info != nullptr);
|
||||
|
||||
if (FindNextNodeStraCosts(use_apply, next_operator_info, in_index)) {
|
||||
return true;
|
||||
|
@ -973,8 +970,8 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
int32_t out_index = 0;
|
||||
OperatorInfoPtr pre_operator_info;
|
||||
std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs;
|
||||
auto operator_info = cnode->GetUserData<OperatorInfo>();
|
||||
if (pre_node->isa<Parameter>()) {
|
||||
OperatorInfoPtr operator_info = cnode->operator_info();
|
||||
auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
|
||||
reshape_info->SetCostForReshapeWithParameter();
|
||||
pre_operator_info = reshape_info;
|
||||
|
@ -995,7 +992,6 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
}
|
||||
// set input_layout and output_layout for reshape.
|
||||
// init reshape and set cost for each input_layout and output_layout.
|
||||
OperatorInfoPtr operator_info = cnode->operator_info();
|
||||
auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
|
||||
reshape_info->set_pre_operator_name(pre_operator_info->name());
|
||||
reshape_info->set_pre_operator_index(out_index);
|
||||
|
|
|
@ -272,7 +272,7 @@ OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) {
|
|||
if (!IsParallelCareNode(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
OperatorInfoPtr distribute_operator = node->operator_info();
|
||||
OperatorInfoPtr distribute_operator = node->GetUserData<OperatorInfo>();
|
||||
if (distribute_operator == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "GetDistributeOperator:distribute_operator is nullptr";
|
||||
}
|
||||
|
@ -415,7 +415,7 @@ bool IsParallelCareNode(const CNodePtr &cnode) {
|
|||
if (prim->name() == GET_NEXT) {
|
||||
return true;
|
||||
}
|
||||
if ((prim->name() == CAST) && (cnode->operator_info() == nullptr)) {
|
||||
if ((prim->name() == CAST) && !cnode->HasUserData<OperatorInfo>()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -452,7 +452,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
|
|||
if (node_prim->name() == DEPEND && node_pair.second != 1) {
|
||||
continue;
|
||||
}
|
||||
if (IsParallelCareNode(use_cnode) && (use_cnode->operator_info() != nullptr)) {
|
||||
if (IsParallelCareNode(use_cnode) && use_cnode->HasUserData<OperatorInfo>()) {
|
||||
Redistribution(node_pair, distribute_operator, insert_node_new, node_pair.second, tensor_redistribution,
|
||||
pre_node);
|
||||
} else {
|
||||
|
@ -465,7 +465,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
|
|||
void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int index) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(next_node);
|
||||
OperatorInfoPtr op_info = next_node->operator_info();
|
||||
OperatorInfoPtr op_info = next_node->GetUserData<OperatorInfo>();
|
||||
MS_EXCEPTION_IF_NULL(op_info);
|
||||
|
||||
// If the shape of tensor is [] or [1], no need to split it.
|
||||
|
@ -590,7 +590,7 @@ void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) {
|
|||
|
||||
void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
|
||||
// step1:get graph manager distribute_operator
|
||||
OperatorInfoPtr distribute_operator = node->operator_info();
|
||||
OperatorInfoPtr distribute_operator = node->GetUserData<OperatorInfo>();
|
||||
if (distribute_operator == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failure:AddNode error since distribute_operator is nullptr";
|
||||
}
|
||||
|
@ -628,7 +628,7 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
|
|||
(void)prim->SetAttrs(attrs);
|
||||
}
|
||||
if (index == replace_op.size() - 1) {
|
||||
(void)replace_node->set_operator_info(node->operator_info());
|
||||
replace_node->SetUserData<OperatorInfo>(node->GetUserData<OperatorInfo>());
|
||||
}
|
||||
replace_node->set_in_forward_flag(true);
|
||||
replace_input[0]->set_scope(scope);
|
||||
|
@ -708,7 +708,7 @@ LossNodeInfo GetLossNodeInfo(const AnfNodePtr &loss_node) {
|
|||
auto pre_cnode = pre_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(pre_cnode);
|
||||
auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
||||
if (pre_prim->name() == CAST && pre_cnode->operator_info() == nullptr) {
|
||||
if (pre_prim->name() == CAST && !pre_cnode->HasUserData<OperatorInfo>()) {
|
||||
pre_node = pre_cnode->input(1);
|
||||
}
|
||||
|
||||
|
@ -1204,7 +1204,7 @@ std::pair<AnfNodePtr, int> FindParallelCareNode(const AnfNodePtr &node) {
|
|||
if (node_prim->name() == DEPEND && node_pair.second != 1) {
|
||||
continue;
|
||||
}
|
||||
if (IsParallelCareNode(cnode) && cnode->operator_info() != nullptr) {
|
||||
if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) {
|
||||
return node_pair;
|
||||
} else if (FindParallelCareNode(node_pair.first).first != nullptr) {
|
||||
return FindParallelCareNode(node_pair.first);
|
||||
|
@ -1254,7 +1254,7 @@ void SetParallelShape(const AnfNodePtr ¶meter, const std::pair<AnfNodePtr, i
|
|||
MS_LOG(DEBUG) << "SetParallelShape " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
|
||||
CNodePtr cnode = res.first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
OperatorInfoPtr distribute_operator = cnode->operator_info();
|
||||
OperatorInfoPtr distribute_operator = cnode->GetUserData<OperatorInfo>();
|
||||
if (distribute_operator == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr";
|
||||
}
|
||||
|
@ -1277,7 +1277,7 @@ void SetParallelShape(const AnfNodePtr ¶meter, const std::pair<AnfNodePtr, i
|
|||
TensorLayout tensor_layout = tensorinfo_in.tensor_layout();
|
||||
ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(parameter_ptr);
|
||||
parameter_ptr->set_tensor_layout(std::make_shared<TensorLayout>(tensor_layout));
|
||||
parameter_ptr->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
|
||||
}
|
||||
|
||||
void CoverSliceShape(const FuncGraphPtr &root) {
|
||||
|
@ -1365,7 +1365,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
|
|||
|
||||
if (found_be_cloned_parameter) {
|
||||
// set the shape and tensor layout for cloned parameter
|
||||
cloned_parameter->set_tensor_layout(cloned_from_parameter->tensor_layout());
|
||||
cloned_parameter->SetUserData<TensorLayout>(cloned_from_parameter->GetUserData<TensorLayout>());
|
||||
MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
|
||||
MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
|
||||
auto cloned_abstract = cloned_parameter_node->abstract()->Clone();
|
||||
|
@ -1464,7 +1464,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
(*operator_).set_outputs_dtype(cnode->Type());
|
||||
(*operator_).set_cnode(cnode);
|
||||
if (prim->name() == RESHAPE) {
|
||||
(void)cnode->set_operator_info(operator_);
|
||||
cnode->SetUserData<OperatorInfo>(operator_);
|
||||
continue;
|
||||
}
|
||||
// load strategy checkpoint
|
||||
|
@ -1499,7 +1499,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
if (operator_->Init(strategyPtr) == FAILED) {
|
||||
MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed";
|
||||
}
|
||||
(void)cnode->set_operator_info(operator_);
|
||||
cnode->SetUserData<OperatorInfo>(operator_);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "ERROR:strategy_ptr is nullptr";
|
||||
}
|
||||
|
@ -1542,13 +1542,13 @@ std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode) {
|
|||
if (node_prim->name() == DEPEND && node_pair.second != 1) {
|
||||
continue;
|
||||
}
|
||||
if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) {
|
||||
if (IsParallelCareNode(use_apply) && use_apply->HasUserData<OperatorInfo>()) {
|
||||
MS_LOG(INFO) << "FindNextLayout success prim " << node_prim->name();
|
||||
auto layout = GetInputLayoutFromCNode(node_pair);
|
||||
return std::make_shared<TensorLayout>(layout);
|
||||
}
|
||||
MS_LOG(DEBUG) << "FindNextLayout failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
|
||||
<< " " << (use_apply->operator_info() != nullptr);
|
||||
<< " " << use_apply->HasUserData<OperatorInfo>();
|
||||
|
||||
auto layout_ptr = FindNextLayout(use_apply);
|
||||
if (layout_ptr) {
|
||||
|
@ -1580,7 +1580,7 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n
|
|||
if (!IsValueNode<Primitive>(cnode->input(0))) {
|
||||
return nullptr;
|
||||
}
|
||||
if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) {
|
||||
if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) {
|
||||
auto layout_ptr = GetOutputLayoutFromCNode(cnode, output_index);
|
||||
if (!layout_ptr) {
|
||||
MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed";
|
||||
|
@ -1624,7 +1624,7 @@ std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) {
|
|||
if (!IsValueNode<Primitive>(cnode->input(0))) {
|
||||
return nullptr;
|
||||
}
|
||||
if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) {
|
||||
if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) {
|
||||
auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0);
|
||||
if (!layout_ptr) {
|
||||
MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed";
|
||||
|
@ -1664,12 +1664,12 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
continue;
|
||||
}
|
||||
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||
if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) {
|
||||
if (!IsParallelCareNode(cnode) || !cnode->HasUserData<OperatorInfo>()) {
|
||||
continue;
|
||||
}
|
||||
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
OperatorInfoPtr operator_info = cnode->operator_info();
|
||||
OperatorInfoPtr operator_info = cnode->GetUserData<OperatorInfo>();
|
||||
if (operator_info == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr";
|
||||
}
|
||||
|
@ -1714,7 +1714,7 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
|
|||
|
||||
auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
||||
// return -> cast
|
||||
if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) {
|
||||
if (current_prim->name() == CAST && !pre_cnode->HasUserData<OperatorInfo>()) {
|
||||
pre_cnode = pre_cnode->input(1)->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(pre_cnode);
|
||||
current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
||||
|
@ -1771,7 +1771,7 @@ TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) {
|
|||
return ret;
|
||||
}
|
||||
|
||||
OperatorInfoPtr operator_info = loss_cnode->operator_info();
|
||||
OperatorInfoPtr operator_info = loss_cnode->GetUserData<OperatorInfo>();
|
||||
MS_EXCEPTION_IF_NULL(operator_info);
|
||||
TensorInfo loss_grad_tensor_info;
|
||||
size_t op_output_size = operator_info->outputs_tensor_info().size();
|
||||
|
@ -1809,7 +1809,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay
|
|||
if (sens_tensor_node->isa<Parameter>()) {
|
||||
auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
|
||||
MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString();
|
||||
sens_tensor_param->set_tensor_layout(std::make_shared<TensorLayout>(loss_grad_layout));
|
||||
sens_tensor_param->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
|
||||
}
|
||||
MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", no need to split sens";
|
||||
return;
|
||||
|
@ -1834,7 +1834,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay
|
|||
cloned_abstract->set_shape(parallel_shape);
|
||||
sens_tensor_node->set_abstract(cloned_abstract);
|
||||
auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
|
||||
sens_tensor_param->set_tensor_layout(std::make_shared<TensorLayout>(loss_grad_layout));
|
||||
sens_tensor_param->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
|
||||
return;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter, it is unsupported now.";
|
||||
|
@ -2125,7 +2125,7 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
OperatorInfoPtr operator_info = cnode->operator_info();
|
||||
OperatorInfoPtr operator_info = cnode->GetUserData<OperatorInfo>();
|
||||
if (operator_info) {
|
||||
if (operator_info->name().find(RESHAPEINFO) != std::string::npos) {
|
||||
continue;
|
||||
|
|
|
@ -83,6 +83,9 @@ class TensorLayout {
|
|||
|
||||
TensorLayout SqueezeShape() const;
|
||||
|
||||
// Key for user data.
|
||||
constexpr static char key[] = "TLayout";
|
||||
|
||||
private:
|
||||
std::shared_ptr<TensorLayout> ExpandTensorShapeWithoutExtendDeviceArrangement(
|
||||
const Arrangement &expanded_shape) const;
|
||||
|
|
|
@ -0,0 +1,160 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPERATOR_OPS_H_
|
||||
#define MINDSPORE_CORE_OPERATOR_OPS_H_
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ir/anf.h"
|
||||
#include "ir/primitive.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace prim {
|
||||
// Maths
|
||||
inline const PrimitivePtr kPrimTensorAdd = std::make_shared<Primitive>("TensorAdd");
|
||||
inline const PrimitivePtr kPrimMatMul = std::make_shared<Primitive>("MatMul");
|
||||
inline const PrimitivePtr kPrimBatchMatMul = std::make_shared<Primitive>("BatchMatMul");
|
||||
inline const PrimitivePtr kPrimMaximumGrad = std::make_shared<Primitive>("MaximumGrad");
|
||||
inline const PrimitivePtr kPrimMinimumGrad = std::make_shared<Primitive>("MinimumGrad");
|
||||
inline const PrimitivePtr kPrimReduceMean = std::make_shared<Primitive>("ReduceMean");
|
||||
inline const PrimitivePtr kPrimReduceSum = std::make_shared<Primitive>("ReduceSum");
|
||||
inline const PrimitivePtr kPrimReduceAll = std::make_shared<Primitive>("ReduceAll");
|
||||
inline const PrimitivePtr kPrimReduceMax = std::make_shared<Primitive>("ReduceMax");
|
||||
inline const PrimitivePtr kPrimReduceMin = std::make_shared<Primitive>("ReduceMin");
|
||||
inline const PrimitivePtr kPrimNeg = std::make_shared<Primitive>("Neg");
|
||||
inline const PrimitivePtr kPrimSub = std::make_shared<Primitive>("Sub");
|
||||
inline const PrimitivePtr kPrimMul = std::make_shared<Primitive>("Mul");
|
||||
inline const PrimitivePtr kPrimMinimum = std::make_shared<Primitive>("Minimum");
|
||||
inline const PrimitivePtr kPrimMaximum = std::make_shared<Primitive>("Maximum");
|
||||
inline const PrimitivePtr kPrimSquare = std::make_shared<Primitive>("Square");
|
||||
inline const PrimitivePtr kPrimCumSum = std::make_shared<Primitive>("CumSum");
|
||||
inline const PrimitivePtr kPrimCumProd = std::make_shared<Primitive>("CumProd");
|
||||
inline const PrimitivePtr kPrimSubscalar = std::make_shared<Primitive>("Subscalar");
|
||||
inline const PrimitivePtr kPrimInplaceAdd = std::make_shared<Primitive>("InplaceAdd");
|
||||
inline const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub");
|
||||
inline const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow");
|
||||
inline const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>("RealDiv");
|
||||
inline const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt");
|
||||
inline const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>("Reciprocal");
|
||||
inline const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandDims");
|
||||
|
||||
// Statements
|
||||
inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return");
|
||||
inline const PrimitivePtr kPrimSwitch = std::make_shared<Primitive>("switch");
|
||||
inline const PrimitivePtr kPrimSwitchLayer = std::make_shared<Primitive>("switch_layer");
|
||||
inline const PrimitivePtr kPrimAssign = std::make_shared<Primitive>("Assign");
|
||||
inline const PrimitivePtr kPrimAssignAdd = std::make_shared<Primitive>("AssignAdd");
|
||||
inline const PrimitivePtr kPrimAssignSub = std::make_shared<Primitive>("AssignSub");
|
||||
inline const PrimitivePtr kPrimSelect = std::make_shared<Primitive>("Select");
|
||||
inline const PrimitivePtr kPrimCall = std::make_shared<Primitive>("call");
|
||||
|
||||
// Structures
|
||||
inline const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal");
|
||||
inline const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat");
|
||||
inline const PrimitivePtr kPrimMakeTuple = std::make_shared<Primitive>("make_tuple");
|
||||
inline const PrimitivePtr kPrimMakeDict = std::make_shared<Primitive>("make_dict");
|
||||
inline const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list");
|
||||
inline const PrimitivePtr kPrimMakeKeywordArg = std::make_shared<Primitive>("make_keyword_arg");
|
||||
inline const PrimitivePtr kPrimMakeSlice = std::make_shared<Primitive>("make_slice");
|
||||
inline const PrimitivePtr kPrimMakeRecord = std::make_shared<Primitive>("make_record");
|
||||
inline const PrimitivePtr kPrimTupleGetItem = std::make_shared<Primitive>("tuple_getitem");
|
||||
inline const PrimitivePtr kPrimListGetItem = std::make_shared<Primitive>("list_getitem");
|
||||
inline const PrimitivePtr kPrimArrayGetItem = std::make_shared<Primitive>("array_getitem");
|
||||
inline const PrimitivePtr kPrimTupleSetItem = std::make_shared<Primitive>("tuple_setitem");
|
||||
inline const PrimitivePtr kPrimListSetItem = std::make_shared<Primitive>("list_setitem");
|
||||
inline const PrimitivePtr kPrimArraySetItem = std::make_shared<Primitive>("array_setitem");
|
||||
inline const PrimitivePtr kPrimDictGetItem = std::make_shared<Primitive>("dict_getitem");
|
||||
inline const PrimitivePtr kPrimDictSetItem = std::make_shared<Primitive>("dict_setitem");
|
||||
inline const PrimitivePtr kPrimListAppend = std::make_shared<Primitive>("list_append");
|
||||
inline const PrimitivePtr kPrimGetAttr = std::make_shared<Primitive>("getattr");
|
||||
inline const PrimitivePtr kPrimTupleLen = std::make_shared<Primitive>("tuple_len");
|
||||
inline const PrimitivePtr kPrimDictLen = std::make_shared<Primitive>("dict_len");
|
||||
inline const PrimitivePtr kPrimListLen = std::make_shared<Primitive>("list_len");
|
||||
inline const PrimitivePtr kPrimArrayLen = std::make_shared<Primitive>("array_len");
|
||||
inline const PrimitivePtr kPrimListMap = std::make_shared<Primitive>("list_map");
|
||||
inline const PrimitivePtr kPrimListReduce = std::make_shared<Primitive>("list_reduce");
|
||||
inline const PrimitivePtr kPrimTupleReversed = std::make_shared<Primitive>("tuple_reversed");
|
||||
inline const PrimitivePtr kPrimTileShape = std::make_shared<Primitive>("tile_shape");
|
||||
inline const PrimitivePtr kPrimReducedShape = std::make_shared<Primitive>("reduced_shape");
|
||||
inline const PrimitivePtr kPrimTupleDiv = std::make_shared<Primitive>("tuple_div");
|
||||
inline const PrimitivePtr kPrimTupleToArray = std::make_shared<Primitive>("tuple_to_array");
|
||||
inline const PrimitivePtr kPrimShapeMul = std::make_shared<Primitive>("shape_mul");
|
||||
inline const PrimitivePtr kPrimGenerateShapeIndex = std::make_shared<Primitive>("generate_shape_index");
|
||||
inline const PrimitivePtr kPrimGenerateInverseIndex = std::make_shared<Primitive>("generate_inverse_index");
|
||||
inline const PrimitivePtr kPrimTupleEqual = std::make_shared<Primitive>("tuple_equal");
|
||||
inline const PrimitivePtr kPrimListEqual = std::make_shared<Primitive>("list_equal");
|
||||
inline const PrimitivePtr kPrimMakeRange = std::make_shared<Primitive>("make_range");
|
||||
inline const PrimitivePtr kPrimStopGradient = std::make_shared<Primitive>("stop_gradient");
|
||||
inline const PrimitivePtr kPrimExtractKeywordArg = std::make_shared<Primitive>("extract_keyword_arg");
|
||||
|
||||
// Debug ops
|
||||
inline const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary");
|
||||
inline const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary");
|
||||
inline const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary");
|
||||
inline const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary");
|
||||
inline const PrimitivePtr kPrimDebug = std::make_shared<Primitive>("Debug");
|
||||
|
||||
// Other miscellaneous
|
||||
inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J");
|
||||
inline const PrimitivePtr kPrimDepend = std::make_shared<Primitive>("Depend");
|
||||
inline const PrimitivePtr kPrimPartial = std::make_shared<Primitive>("Partial");
|
||||
inline const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity");
|
||||
inline const PrimitivePtr kPrimEnvSetItem = std::make_shared<Primitive>("env_setitem");
|
||||
inline const PrimitivePtr kPrimEnvGetItem = std::make_shared<Primitive>("env_getitem");
|
||||
inline const PrimitivePtr kPrimEnvAdd = std::make_shared<Primitive>("env_add");
|
||||
inline const PrimitivePtr kPrimMakeRefKey = std::make_shared<Primitive>("MakeRefKey");
|
||||
inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key");
|
||||
inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
|
||||
inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin");
|
||||
inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
|
||||
inline const PrimitivePtr kPrimHookBackward = std::make_shared<Primitive>("HookBackward");
|
||||
inline const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType");
|
||||
inline const PrimitivePtr kPrimSameTypeShape = std::make_shared<Primitive>("SameTypeShape");
|
||||
inline const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop");
|
||||
inline const PrimitivePtr kPrimPrint = std::make_shared<Primitive>("Print");
|
||||
inline const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");
|
||||
inline const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem");
|
||||
inline const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("BroadcastGradientArgs");
|
||||
inline const PrimitivePtr kPrimControlDepend = std::make_shared<Primitive>("ControlDepend");
|
||||
inline const PrimitivePtr kPrimIs_ = std::make_shared<Primitive>("is_");
|
||||
inline const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not");
|
||||
inline const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict");
|
||||
inline const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
|
||||
inline const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast");
|
||||
inline const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant");
|
||||
inline const PrimitivePtr kPrimEquivFormat = std::make_shared<Primitive>("EquivFormat");
|
||||
|
||||
class DoSignaturePrimitive : public Primitive {
|
||||
public:
|
||||
explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function)
|
||||
: Primitive("S-Prim-" + name), function_(function) {}
|
||||
|
||||
~DoSignaturePrimitive() override = default;
|
||||
|
||||
MS_DECLARE_PARENT(DoSignaturePrimitive, Primitive)
|
||||
|
||||
const ValuePtr function() const { return function_; }
|
||||
|
||||
private:
|
||||
ValuePtr function_;
|
||||
};
|
||||
using DoSignaturePrimitivePtr = std::shared_ptr<DoSignaturePrimitive>;
|
||||
} // namespace prim
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPERATOR_OPS_H_
|
|
@ -0,0 +1,52 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_USER_DATA_H_
|
||||
#define MINDSPORE_CORE_USER_DATA_H_
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
|
||||
namespace mindspore {
|
||||
class UserData {
|
||||
public:
|
||||
template <typename T>
|
||||
void set(const std::string &key, const std::shared_ptr<T> &value) {
|
||||
if (value == nullptr) {
|
||||
data_.erase(key);
|
||||
} else {
|
||||
data_.insert_or_assign(key, value);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::shared_ptr<T> get(const std::string &key) const {
|
||||
auto iter = data_.find(key);
|
||||
if (iter == data_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return std::static_pointer_cast<T>(iter->second);
|
||||
}
|
||||
|
||||
bool has(const std::string &key) const { return data_.find(key) != data_.end(); }
|
||||
|
||||
private:
|
||||
std::map<std::string, std::shared_ptr<void>> data_;
|
||||
};
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_USER_DATA_H_
|
|
@ -26,7 +26,6 @@
|
|||
#include "ir/func_graph.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "utils/context/ms_context.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
// namespace to support intermediate representation definition
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include <utility>
|
||||
|
||||
#include "base/base.h"
|
||||
#include "base/user_data.h"
|
||||
#include "ir/kernel_info_dev.h"
|
||||
#include "ir/scope.h"
|
||||
#include "debug/info.h"
|
||||
|
@ -41,12 +42,6 @@
|
|||
// ANode: Atomic Node
|
||||
// CNode: Complex Node
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
class TensorLayout;
|
||||
class OperatorInfo;
|
||||
} // namespace parallel
|
||||
using OperatorInfoPtr = std::shared_ptr<parallel::OperatorInfo>;
|
||||
|
||||
namespace abstract {
|
||||
class BaseShape;
|
||||
class AbstractBase;
|
||||
|
@ -157,6 +152,31 @@ class AnfNode : public Base {
|
|||
}
|
||||
size_t seen_{0};
|
||||
|
||||
template <typename T>
|
||||
void SetUserData(const std::string &key, const std::shared_ptr<T> &value) {
|
||||
user_data_.set<T>(key, value);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void SetUserData(const std::shared_ptr<T> &value) {
|
||||
user_data_.set<T>(T::key, value);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::shared_ptr<T> GetUserData(const std::string &key) const {
|
||||
return user_data_.get<T>(key);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::shared_ptr<T> GetUserData() const {
|
||||
return user_data_.get<T>(T::key);
|
||||
}
|
||||
|
||||
bool HasUserData(const std::string &key) const { return user_data_.has(key); }
|
||||
|
||||
template <typename T>
|
||||
bool HasUserData() const { return user_data_.has(T::key); }
|
||||
|
||||
protected:
|
||||
// Hold a weak ref to Graph as Graph also hold ref to AnfNode.
|
||||
// Otherwise, func_graph_ and AnfNode will make a reference cycle.
|
||||
|
@ -170,6 +190,7 @@ class AnfNode : public Base {
|
|||
std::hash<const AnfNode *> hash_;
|
||||
ScopePtr scope_;
|
||||
KernelInfoDevicePtr kernel_info_;
|
||||
UserData user_data_;
|
||||
};
|
||||
|
||||
// CNode represents the complex node with a set of arguments.
|
||||
|
@ -212,9 +233,6 @@ class CNode : public AnfNode {
|
|||
std::string DebugString(int recursive_level = 1) const override;
|
||||
std::string DebugString(bool recursive) const override { return DebugString(recursive ? 1 : 0); }
|
||||
|
||||
OperatorInfoPtr set_operator_info(const OperatorInfoPtr &operator_info);
|
||||
OperatorInfoPtr operator_info() { return operator_info_; }
|
||||
|
||||
void set_in_forward_flag(bool flag) { in_forward_flag_ = flag; }
|
||||
bool in_forward_flag() const { return in_forward_flag_; }
|
||||
|
||||
|
@ -224,7 +242,6 @@ class CNode : public AnfNode {
|
|||
std::vector<AnfNodePtr> inputs_;
|
||||
VarPtr func_graph_as_var_;
|
||||
bool stop_gradient_;
|
||||
OperatorInfoPtr operator_info_ = nullptr;
|
||||
bool in_forward_flag_ = false;
|
||||
};
|
||||
|
||||
|
@ -244,7 +261,7 @@ class ANode : public AnfNode {
|
|||
class Parameter : public ANode {
|
||||
public:
|
||||
explicit Parameter(const FuncGraphPtr &func_graph)
|
||||
: ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr), tensor_layout_(nullptr) {}
|
||||
: ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr) {}
|
||||
~Parameter() override = default;
|
||||
MS_DECLARE_PARENT(Parameter, ANode);
|
||||
|
||||
|
@ -261,11 +278,6 @@ class Parameter : public ANode {
|
|||
}
|
||||
ParamValuePtr default_param() const { return default_param_; }
|
||||
|
||||
std::shared_ptr<parallel::TensorLayout> tensor_layout() const { return tensor_layout_; }
|
||||
void set_tensor_layout(const std::shared_ptr<parallel::TensorLayout> &tensor_layout) {
|
||||
tensor_layout_ = tensor_layout;
|
||||
}
|
||||
|
||||
bool operator==(const AnfNode &other) const override {
|
||||
if (!other.isa<Parameter>()) {
|
||||
return false;
|
||||
|
@ -281,7 +293,6 @@ class Parameter : public ANode {
|
|||
std::string name_;
|
||||
bool has_default_;
|
||||
ParamValuePtr default_param_;
|
||||
std::shared_ptr<parallel::TensorLayout> tensor_layout_;
|
||||
};
|
||||
using ParameterPtr = std::shared_ptr<Parameter>;
|
||||
|
||||
|
|
|
@ -23,8 +23,7 @@
|
|||
|
||||
#include "ir/visitor.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "frontend/parallel/ops_info/ops_utils.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "debug/label.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -37,18 +36,6 @@ std::string AnfNode::ToString() const {
|
|||
return mindspore::label_manage::Label(const_cast<AnfNode *>(this)->shared_from_base<AnfNode>()->debug_info());
|
||||
}
|
||||
|
||||
OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) {
|
||||
if (operator_info_ != nullptr) {
|
||||
MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name()
|
||||
<< ", using the new one: " << operator_info->name();
|
||||
auto old_ptr = operator_info_;
|
||||
operator_info_ = operator_info;
|
||||
return old_ptr;
|
||||
}
|
||||
operator_info_ = operator_info;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::string CNode::fullname_with_scope() {
|
||||
// if full name is set, return its name immediately
|
||||
if (!fullname_with_scope_.empty()) {
|
||||
|
|
|
@ -24,7 +24,6 @@
|
|||
|
||||
#include "debug/trace.h"
|
||||
#include "ir/manager.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "utils/ordered_set.h"
|
||||
#include "utils/convert_utils_base.h"
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
|
||||
#include "ir/manager.h"
|
||||
#include "ir/param_value.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "utils/convert_utils_base.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/profile.h"
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
|
||||
#include "ir/manager.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "utils/ordered_set.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
#include "ir/func_graph.h"
|
||||
#include "utils/profile.h"
|
||||
#include "utils/convert_utils_base.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "base/core_ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
||||
|
|
|
@ -17,10 +17,8 @@
|
|||
*/
|
||||
|
||||
#include "ir/meta_func_graph.h"
|
||||
#include "pipeline/jit/static_analysis/static_analysis.h"
|
||||
#include "pipeline/jit/static_analysis/abstract_function.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "utils/context/ms_context.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
|
||||
// namespace to support intermediate representation definition
|
||||
namespace mindspore {
|
||||
|
|
|
@ -22,9 +22,9 @@
|
|||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "frontend/optimizer/optimizer.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/optimizer_caller.h"
|
||||
#include "base/core_ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
///
|
||||
|
|
|
@ -25,7 +25,6 @@
|
|||
|
||||
#include "ir/dtype/type.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "frontend/parallel/ops_info/operator_info.h"
|
||||
#include "utils/base_ref_extends.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
|
|
@ -18,7 +18,6 @@
|
|||
#include <mutex>
|
||||
#include <utility>
|
||||
#include "ir/signature.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "./common.h"
|
||||
#include "pipeline/jit/parse/python_adapter.h"
|
||||
#include "pipeline/jit/parse/data_converter.h"
|
||||
|
|
|
@ -28,7 +28,6 @@
|
|||
#include <type_traits>
|
||||
#include <typeinfo>
|
||||
|
||||
#include "runtime/device/device_address.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
|
|
@ -153,7 +153,7 @@ TEST_F(TestStepAutoParallel, test_create_op_instance) {
|
|||
StrategyPtr strategyPtr;
|
||||
|
||||
std::shared_ptr<OperatorInfo> matmul_info = NewOperatorInstance(prim, attrs, shape);
|
||||
node->set_operator_info(matmul_info);
|
||||
node->SetUserData<OperatorInfo>(matmul_info);
|
||||
std::string name_expect = "MatMulInfo00";
|
||||
std::string name_test = matmul_info->name();
|
||||
ASSERT_EQ(name_expect, name_test);
|
||||
|
|
|
@ -525,8 +525,8 @@ TEST_F(TestStepParallel, GetTensorInLayout) {
|
|||
std::vector<Shapes> shape = {inputs_shape, outputs_shape};
|
||||
OperatorInfoPtr matmul_info = OperatorInstance(prim, attrs, shape);
|
||||
matmul_info->Init(strategyPtr);
|
||||
node->set_operator_info(matmul_info);
|
||||
OperatorInfoPtr distribute_operator_pre = node->operator_info();
|
||||
node->SetUserData<OperatorInfo>(matmul_info);
|
||||
OperatorInfoPtr distribute_operator_pre = node->GetUserData<OperatorInfo>();
|
||||
TensorLayout tensorlayout_e;
|
||||
std::vector<int32_t> array = {64, 64};
|
||||
TensorLayout tensorlayout = GetTensorInLayout(node1, prim, distribute_operator_pre);
|
||||
|
|
Loading…
Reference in New Issue