add some ops

This commit is contained in:
jinyaohui 2021-02-03 19:47:53 +08:00
parent f9d9bba927
commit d9be0c102d
608 changed files with 28986 additions and 1834 deletions

View File

@ -333,14 +333,13 @@ if(CMAKE_SYSTEM_NAME MATCHES "Windows")
target_link_libraries(mindspore mindspore::pybind11_module)
target_link_libraries(mindspore mindspore_gvar)
target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore -Wl,--no-whole-archive)
elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin")
elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin")
target_link_libraries(mindspore mindspore::pybind11_module)
target_link_libraries(mindspore mindspore_gvar)
target_link_libraries(_c_expression PRIVATE -Wl,-force_load mindspore -Wl,-noall_load)
else()
if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
target_link_libraries(mindspore mindspore::pslite proto_input mindspore::protobuf
mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a)
else ()
if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
target_link_libraries(mindspore mindspore::pslite proto_input mindspore::protobuf mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a)
target_link_libraries(mindspore -Wl,--no-as-needed mindspore::event_core ps_cache)
if(${ENABLE_IBVERBS} STREQUAL "ON")
target_link_libraries(mindspore ibverbs rdmacm)

View File

@ -717,7 +717,7 @@ std::unordered_set<PrimitivePtr> GetExpandOps() {
prim::kPrimMinimumGrad,
prim::kPrimGkDropout,
prim::kPrimDropoutGrad,
prim::kPrimSoftMax,
prim::kPrimSoftmax,
prim::kPrimLayerNorm,
prim::kPrimLayerNormGrad,
#endif

View File

@ -20,7 +20,7 @@
#include <unordered_map>
#include <utility>
#include "c_ops/primitive_c.h"
#include "ops/primitive_c.h"
#include "ir/manager.h"
#include "abstract/utils.h"
#include "backend/kernel_compiler/common_utils.h"
@ -1750,8 +1750,8 @@ void SessionBasic::RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<
input_abstracts.emplace_back(abstract);
}
auto prim = AnfAlgo::GetCNodePrimitive(node);
if (prim->isa<PrimitiveC>()) {
auto prim_c = prim->cast<std::shared_ptr<PrimitiveC>>();
if (prim->isa<ops::PrimitiveC>()) {
auto prim_c = prim->cast<std::shared_ptr<ops::PrimitiveC>>();
MS_EXCEPTION_IF_NULL(prim_c);
auto abstract = prim_c->Infer(input_abstracts);
node->set_abstract(abstract);

View File

@ -8,16 +8,16 @@ endif()
message("************ build core ***************")
file(GLOB_RECURSE CORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"abstract/*.cc"
"base/*.cc"
"c_ops/*.cc"
"ir/*.cc"
"utils/*.cc"
"load_mindir/*.cc"
)
"abstract/*.cc"
"base/*.cc"
"ops/*.cc"
"ir/*.cc"
"utils/*.cc"
"load_mindir/*.cc"
)
if(CMAKE_SYSTEM_NAME MATCHES "Windows")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-attributes -DHAVE_SNPRINTF")
add_compile_definitions(BUILDING_DLL)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-attributes -DHAVE_SNPRINTF")
add_compile_definitions(BUILDING_DLL)
elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} \
-Wuser-defined-warnings -Winconsistent-missing-override -Wno-delete-non-abstract-non-virtual-dtor")
@ -28,5 +28,5 @@ add_library(mindspore_core STATIC ${CORE_SRC_LIST})
target_link_libraries(mindspore_core PRIVATE mindspore_gvar)
if(USE_GLOG)
target_link_libraries(mindspore_core PRIVATE mindspore::glog)
target_link_libraries(mindspore_core PRIVATE mindspore::glog)
endif()

View File

@ -91,7 +91,9 @@ inline const PrimitivePtr kPrimLabelSwitch = std::make_shared<Primitive>("LabelS
inline const PrimitivePtr kPrimLabelSet = std::make_shared<Primitive>("LabelSet");
// Arrays
inline const PrimitivePtr kPrimBroadcastTo = std::make_shared<Primitive>("BroadcastTo");
inline const PrimitivePtr kPrimScalarToArray = std::make_shared<Primitive>("scalar_to_array");
inline const PrimitivePtr kPrimTopK = std::make_shared<Primitive>("TopK");
inline const PrimitivePtr kPrimArrayToScalar = std::make_shared<Primitive>("array_to_scalar");
inline const PrimitivePtr kPrimBroadcastShape = std::make_shared<Primitive>("broadcast_shape");
inline const PrimitivePtr kPrimArrayMap = std::make_shared<Primitive>("array_map");
@ -99,17 +101,25 @@ inline const PrimitivePtr kPrimArrayReduce = std::make_shared<Primitive>("array_
inline const PrimitivePtr kPrimCast = std::make_shared<Primitive>("Cast");
inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat");
inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze");
inline const PrimitivePtr kPrimUnsqueeze = std::make_shared<Primitive>("Unsqueeze");
inline const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose");
inline const PrimitivePtr kPrimGatherV2 = std::make_shared<Primitive>("GatherV2");
inline const PrimitivePtr kPrimGatherD = std::make_shared<Primitive>("GatherD");
inline const PrimitivePtr kPrimGather = std::make_shared<Primitive>(kGather);
inline const PrimitivePtr kPrimGather = std::make_shared<Primitive>("Gather");
inline const PrimitivePtr kPrimGatherND = std::make_shared<Primitive>("GatherND");
inline const PrimitivePtr kPrimSparseGatherV2 = std::make_shared<Primitive>("SparseGatherV2");
inline const PrimitivePtr kPrimSparseToDense = std::make_shared<Primitive>("SparseToDense");
inline const PrimitivePtr kPrimShape = std::make_shared<Primitive>("Shape");
inline const PrimitivePtr kPrimStridedSlice = std::make_shared<Primitive>("StridedSlice");
inline const PrimitivePtr kPrimDynamicShape = std::make_shared<Primitive>("DynamicShape");
inline const PrimitivePtr kPrimEmbeddingLookup = std::make_shared<Primitive>("EmbeddingLookup");
inline const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared<Primitive>("EmbeddingLookupCommGrad");
inline const PrimitivePtr kPrimSize = std::make_shared<Primitive>("Size");
inline const PrimitivePtr kPrimArgMax = std::make_shared<Primitive>("Argmax");
inline const PrimitivePtr kPrimArgMin = std::make_shared<Primitive>("Argmin");
inline const PrimitivePtr kPrimPack = std::make_shared<Primitive>("Pack");
inline const PrimitivePtr kPrimUnpack = std::make_shared<Primitive>("Unpack");
inline const PrimitivePtr kPrimUnstack = std::make_shared<Primitive>("Unstack");
inline const PrimitivePtr kPrimUnsortedSegmentMax = std::make_shared<Primitive>("UnsortedSegmentMax");
inline const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>("UnsortedSegmentSum");
inline const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin");
@ -123,6 +133,7 @@ inline const PrimitivePtr kPrimCacheSwapTable = std::make_shared<Primitive>("Cac
inline const PrimitivePtr kPrimDynamicAssign = std::make_shared<Primitive>("DynamicAssign");
inline const PrimitivePtr kPrimPadAndShift = std::make_shared<Primitive>("PadAndShift");
inline const PrimitivePtr kPrimSlice = std::make_shared<Primitive>("Slice");
inline const PrimitivePtr kPrimSliceFusion = std::make_shared<Primitive>("SliceFusion");
inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile");
inline const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN");
inline const PrimitivePtr kPrimAccumulateNV2 = std::make_shared<Primitive>("AccumulateNV2");
@ -145,16 +156,36 @@ inline const PrimitivePtr kPrimSequenceMask = std::make_shared<Primitive>("Seque
inline const PrimitivePtr kPrimRange = std::make_shared<Primitive>("Range");
inline const PrimitivePtr kPrimSpaceToBatchND = std::make_shared<Primitive>("SpaceToBatchND");
inline const PrimitivePtr kPrimBatchToSpaceND = std::make_shared<Primitive>("BatchToSpaceND");
inline const PrimitivePtr kPrimDepthToSpace = std::make_shared<Primitive>("DepthToSpace");
inline const PrimitivePtr kPrimBatchToSpace = std::make_shared<Primitive>("BatchToSpace");
inline const PrimitivePtr kPrimSpaceToBatch = std::make_shared<Primitive>("SpaceToBatch");
inline const PrimitivePtr kPrimScatterNd = std::make_shared<Primitive>("ScatterNd");
inline const PrimitivePtr kPrimConstantOfShape = std::make_shared<Primitive>("ConstantOfShape");
inline const PrimitivePtr kPrimSquaredDifference = std::make_shared<Primitive>("SquaredDifference");
inline const PrimitivePtr kPrimReverseV2 = std::make_shared<Primitive>("ReverseV2");
inline const PrimitivePtr kPrimReverseSequence = std::make_shared<Primitive>("ReverseSequence");
inline const PrimitivePtr kPrimRank = std::make_shared<Primitive>("Rank");
inline const PrimitivePtr kPrimResizeBilinear = std::make_shared<Primitive>("ResizeBilinear");
// NN
inline const PrimitivePtr kPrimAdam = std::make_shared<Primitive>("Adam");
inline const PrimitivePtr kPrimAudioSpectrogram = std::make_shared<Primitive>("AudioSpectrogram");
inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
inline const PrimitivePtr kPrimSoftMax = std::make_shared<Primitive>("Softmax");
inline const PrimitivePtr kPrimCrop = std::make_shared<Primitive>("Crop");
inline const PrimitivePtr kPrimFlattenGrad = std::make_shared<Primitive>("FlattenGrad");
inline const PrimitivePtr kPrimSoftmax = std::make_shared<Primitive>("Softmax");
inline const PrimitivePtr kPrimSparseSoftmaxCrossEntropy = std::make_shared<Primitive>("SparseSoftmaxCrossEntropy");
inline const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax");
inline const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad");
inline const PrimitivePtr kPrimLstm = std::make_shared<Primitive>("Lstm");
inline const PrimitivePtr kPrimTan = std::make_shared<Primitive>("Tan");
inline const PrimitivePtr kPrimAtan = std::make_shared<Primitive>("Atan");
inline const PrimitivePtr kPrimAsin = std::make_shared<Primitive>("Asin");
inline const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh");
inline const PrimitivePtr kPrimTanhGrad = std::make_shared<Primitive>("TanhGrad");
inline const PrimitivePtr kPrimPooling = std::make_shared<Primitive>("Pooling");
inline const PrimitivePtr kPrimPoolingGrad = std::make_shared<Primitive>("PoolingGrad");
inline const PrimitivePtr kPrimROIPooling = std::make_shared<Primitive>("ROIPooling");
inline const PrimitivePtr kPrimMaxPool = std::make_shared<Primitive>("MaxPool");
inline const PrimitivePtr kPrimMaxPoolGrad = std::make_shared<Primitive>("MaxPoolGrad");
inline const PrimitivePtr kPrimMaxPoolWithArgmax = std::make_shared<Primitive>("MaxPoolWithArgmax");
@ -168,6 +199,9 @@ inline const PrimitivePtr kPrimFusedSparseAdam = std::make_shared<Primitive>("Fu
inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm");
inline const PrimitivePtr kPrimFusedBatchNormEx = std::make_shared<Primitive>("FusedBatchNormEx");
inline const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D");
inline const PrimitivePtr kPrimFullConnection = std::make_shared<Primitive>("FullConnection");
inline const PrimitivePtr kPrimConv2DTranspose = std::make_shared<Primitive>("Conv2DTranspose");
inline const PrimitivePtr kPrimGroupConv2DGradInput = std::make_shared<Primitive>("GroupConv2DGradInput");
inline const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>("FusedBatchNormGrad");
inline const PrimitivePtr kPrimFusedBatchNormGradEx = std::make_shared<Primitive>("FusedBatchNormGradEx");
inline const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm");
@ -179,21 +213,34 @@ inline const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared<Primitive>
inline const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive>("Conv2DBackpropFilter");
inline const PrimitivePtr kPrimConv3DBackpropInput = std::make_shared<Primitive>("Conv3DBackpropInput");
inline const PrimitivePtr kPrimConv3DBackpropFilter = std::make_shared<Primitive>("Conv3DBackpropFilter");
inline const PrimitivePtr kPrimCustomNormalize = std::make_shared<Primitive>("CustomNormalize");
inline const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared<Primitive>("DepthwiseConv2dNative");
inline const PrimitivePtr kPrimCTCGreedyDecoder = std::make_shared<Primitive>("CTCGreedyDecoder");
inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter =
std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropFilter");
inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput =
std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropInput");
inline const PrimitivePtr kPrimDetectionPostProcess = std::make_shared<Primitive>("DetectionPostProcess");
inline const PrimitivePtr kPrimBiasAdd = std::make_shared<Primitive>("BiasAdd");
inline const PrimitivePtr kPrimBiasGrad = std::make_shared<Primitive>("BiasGrad");
inline const PrimitivePtr kPrimBiasAddGrad = std::make_shared<Primitive>("BiasAddGrad");
inline const PrimitivePtr kPrimBiasSubGrad = std::make_shared<Primitive>("BiasSubGrad");
inline const PrimitivePtr kPrimBinaryCrossEntropy = std::make_shared<Primitive>("BinaryCrossEntropy");
inline const PrimitivePtr kPrimBinaryCrossEntropyGrad = std::make_shared<Primitive>("BinaryCrossEntropyGrad");
inline const PrimitivePtr kPrimSmoothL1Loss = std::make_shared<Primitive>("SmoothL1Loss");
inline const PrimitivePtr kPrimSmoothL1LossGrad = std::make_shared<Primitive>("SmoothL1LossGrad");
inline const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits =
std::make_shared<Primitive>("SoftmaxCrossEntropyWithLogits");
inline const PrimitivePtr kPrimSigmoidCrossEntropyWithLogits =
std::make_shared<Primitive>("SigmoidCrossEntropyWithLogits");
inline const PrimitivePtr kPrimSigmoidCrossEntropyWithLogitsGrad =
std::make_shared<Primitive>("SigmoidCrossEntropyWithLogitsGrad");
inline const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits =
std::make_shared<Primitive>("SparseSoftmaxCrossEntropyWithLogits");
inline const PrimitivePtr kPrimMomentum = std::make_shared<Primitive>("Momentum");
inline const PrimitivePtr kPrimApplyMomentum = std::make_shared<Primitive>("ApplyMomentum");
inline const PrimitivePtr kPrimLayerNorm = std::make_shared<Primitive>("LayerNorm");
inline const PrimitivePtr kPrimLrn = std::make_shared<Primitive>("Lrn");
inline const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>("LayerNormGrad");
inline const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared<Primitive>("LayerNormXBackprop");
inline const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Primitive>("LayerNormBetaGammaBackprop");
@ -204,18 +251,22 @@ inline const PrimitivePtr kPrimDropout = std::make_shared<Primitive>("Dropout");
inline const PrimitivePtr kPrimUniformReal = std::make_shared<Primitive>("UniformReal");
inline const PrimitivePtr kPrimCudnnUniformReal = std::make_shared<Primitive>("CudnnUniformReal");
inline const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot");
inline const PrimitivePtr kPrimGeLU = std::make_shared<Primitive>("Gelu");
inline const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu");
inline const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad");
inline const PrimitivePtr kPrimFastGelu = std::make_shared<Primitive>("FastGelu");
inline const PrimitivePtr kPrimFastGeluGrad = std::make_shared<Primitive>("FastGeluGrad");
inline const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU");
inline const PrimitivePtr kPrimElu = std::make_shared<Primitive>("Elu");
inline const PrimitivePtr kPrimRelu6 = std::make_shared<Primitive>("ReLU6");
inline const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
inline const PrimitivePtr kPrimPRelu = std::make_shared<Primitive>("PReLU");
inline const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike");
inline const PrimitivePtr kPrimOnesLike = std::make_shared<Primitive>("OnesLike");
inline const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");
inline const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer");
inline const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel");
inline const PrimitivePtr kPrimFakeQuantWithMinMaxVars = std::make_shared<Primitive>("FakeQuantWithMinMaxVars");
inline const PrimitivePtr kPrimApplyRMSProp = std::make_shared<Primitive>("ApplyRMSProp");
inline const PrimitivePtr kPrimSparseApplyFtrl = std::make_shared<Primitive>("SparseApplyFtrl");
inline const PrimitivePtr kPrimSparseApplyProximalAdagrad = std::make_shared<Primitive>("SparseApplyProximalAdagrad");
@ -224,6 +275,8 @@ inline const PrimitivePtr kPrimFusedAdamWeightDecay = std::make_shared<Primitive
inline const PrimitivePtr kPrimSGD = std::make_shared<Primitive>("SGD");
inline const PrimitivePtr kPrimClipByNormNoDivSum = std::make_shared<Primitive>("ClipByNormNoDivSum");
inline const PrimitivePtr kPrimTensorMove = std::make_shared<Primitive>("TensorMove");
inline const PrimitivePtr kPrimL2Normalize = std::make_shared<Primitive>("L2Normalize");
inline const PrimitivePtr kPrimCustomExtractFeatures = std::make_shared<Primitive>("CustomExtractFeatures");
// Comm ops
inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
@ -239,6 +292,12 @@ inline const PrimitivePtr kPrimAllGather = std::make_shared<Primitive>("AllGathe
inline const PrimitivePtr kPrimReduceScatter = std::make_shared<Primitive>("ReduceScatter");
inline const PrimitivePtr kPrimMemCpyAsync = std::make_shared<Primitive>("memcpy_async");
inline const PrimitivePtr kPrimFill = std::make_shared<Primitive>("Fill");
// Quant ops
inline const PrimitivePtr kPrimBatchNormFold = std::make_shared<Primitive>("BatchNormFold");
inline const PrimitivePtr kPrimFakeQuantWithMinMaxVarsPerChannel =
std::make_shared<Primitive>("FakeQuantWithMinMaxVarsPerChannel");
// Control ops
inline const PrimitivePtr kPrimMerge = std::make_shared<Primitive>("Merge");
// RowTensor
inline const PrimitivePtr kPrimMakeRowTensor = std::make_shared<Primitive>("MakeRowTensor");
inline const PrimitivePtr kPrimRowTensorGetValues = std::make_shared<Primitive>("RowTensorGetValues");
@ -251,12 +310,22 @@ inline const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared<Primitiv
inline const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primitive>("SparseTensorGetIndices");
inline const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape");
// TensorList
inline const PrimitivePtr kPrimTensorListFromTensor = std::make_shared<Primitive>("TensorListFromTensor");
inline const PrimitivePtr kPrimTensorListReserve = std::make_shared<Primitive>("TensorListReserve");
inline const PrimitivePtr kPrimTensorListStack = std::make_shared<Primitive>("TensorListStack");
inline const PrimitivePtr kPrimTensorListSetItem = std::make_shared<Primitive>("TensorListSetItem");
// Maths
inline const PrimitivePtr kPrimCeil = std::make_shared<Primitive>("Ceil");
inline const PrimitivePtr kPrimTensorAdd = std::make_shared<Primitive>("TensorAdd");
inline const PrimitivePtr kPrimAdd = std::make_shared<Primitive>("Add");
inline const PrimitivePtr kPrimMatMul = std::make_shared<Primitive>("MatMul");
inline const PrimitivePtr kPrimMatrixDiag = std::make_shared<Primitive>("MatrixDiag");
inline const PrimitivePtr kPrimBatchMatMul = std::make_shared<Primitive>("BatchMatMul");
inline const PrimitivePtr kPrimMaximumGrad = std::make_shared<Primitive>("MaximumGrad");
inline const PrimitivePtr kPrimMinimumGrad = std::make_shared<Primitive>("MinimumGrad");
inline const PrimitivePtr kPrimReduce = std::make_shared<Primitive>("Reduce");
inline const PrimitivePtr kPrimReduceMean = std::make_shared<Primitive>("ReduceMean");
inline const PrimitivePtr kPrimReduceSum = std::make_shared<Primitive>("ReduceSum");
inline const PrimitivePtr kPrimReduceAll = std::make_shared<Primitive>("ReduceAll");
@ -264,6 +333,8 @@ inline const PrimitivePtr kPrimReduceAny = std::make_shared<Primitive>("ReduceAn
inline const PrimitivePtr kPrimReduceMax = std::make_shared<Primitive>("ReduceMax");
inline const PrimitivePtr kPrimReduceMin = std::make_shared<Primitive>("ReduceMin");
inline const PrimitivePtr kPrimNeg = std::make_shared<Primitive>("Neg");
inline const PrimitivePtr kPrimSin = std::make_shared<Primitive>("Sin");
inline const PrimitivePtr kPrimCos = std::make_shared<Primitive>("Cos");
inline const PrimitivePtr kPrimSub = std::make_shared<Primitive>("Sub");
inline const PrimitivePtr kPrimMul = std::make_shared<Primitive>("Mul");
inline const PrimitivePtr kPrimDiv = std::make_shared<Primitive>("Div");
@ -279,6 +350,7 @@ inline const PrimitivePtr kPrimSubscalar = std::make_shared<Primitive>("Subscala
inline const PrimitivePtr kPrimInplaceAdd = std::make_shared<Primitive>("InplaceAdd");
inline const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub");
inline const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow");
inline const PrimitivePtr kPrimPower = std::make_shared<Primitive>("Power");
inline const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>("RealDiv");
inline const PrimitivePtr kPrimFloorDiv = std::make_shared<Primitive>("FloorDiv");
inline const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt");
@ -292,12 +364,13 @@ inline const PrimitivePtr kPrimLog = std::make_shared<Primitive>("Log");
inline const PrimitivePtr kPrimRsqrt = std::make_shared<Primitive>("Rsqrt");
inline const PrimitivePtr kPrimSplitV = std::make_shared<Primitive>("SplitV");
inline const PrimitivePtr kPrimLinSpace = std::make_shared<Primitive>("LinSpace");
inline const PrimitivePtr kPrimNonMaxSuppression = std::make_shared<Primitive>("NonMaxSuppression");
inline const PrimitivePtr kPrimSign = std::make_shared<Primitive>("Sign");
inline const PrimitivePtr kPrimSquaredDifference = std::make_shared<Primitive>("SquaredDifference");
inline const PrimitivePtr kPrimAsin = std::make_shared<Primitive>("Asin");
inline const PrimitivePtr kPrimACos = std::make_shared<Primitive>("ACos");
inline const PrimitivePtr kPrimAsinGrad = std::make_shared<Primitive>("AsinGrad");
inline const PrimitivePtr kPrimACosGrad = std::make_shared<Primitive>("ACosGrad");
inline const PrimitivePtr kPrimFloorMod = std::make_shared<Primitive>("FloorMod");
inline const PrimitivePtr kPrimWhere = std::make_shared<Primitive>("Where");
// Statements
inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return");
@ -323,6 +396,7 @@ inline const PrimitivePtr kPrimGenerateShapeIndex = std::make_shared<Primitive>(
inline const PrimitivePtr kPrimGenerateInverseIndex = std::make_shared<Primitive>("generate_inverse_index");
// Debug ops
inline const PrimitivePtr kPrimAssert = std::make_shared<Primitive>("Assert");
inline const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary");
inline const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary");
inline const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary");
@ -349,6 +423,13 @@ inline const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict");
inline const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
inline const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant");
inline const PrimitivePtr kPrimEquivFormat = std::make_shared<Primitive>("EquivFormat");
inline const PrimitivePtr kPrimLshProjection = std::make_shared<Primitive>("LshProjection");
inline const PrimitivePtr kPrimHashtableLookup = std::make_shared<Primitive>("HashtableLookup");
inline const PrimitivePtr kPrimCustomPredict = std::make_shared<Primitive>("CustomPredict");
inline const PrimitivePtr kPrimStack = std::make_shared<Primitive>("Stack");
inline const PrimitivePtr kPrimPriorBox = std::make_shared<Primitive>("PriorBox");
inline const PrimitivePtr kPrimQuantDTypeCast = std::make_shared<Primitive>("QuantDTypeCast");
inline const PrimitivePtr kPrimWhile = std::make_shared<Primitive>("While");
// Structures
inline const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list");
@ -371,7 +452,7 @@ inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_
inline const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");
inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
// Other primitive not used by backend but used in core;
// Other primitve not used by backend but used in core;
inline const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem");
inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J");
@ -382,6 +463,44 @@ inline const PrimitivePtr kPrimMakeDict = std::make_shared<Primitive>("make_dict
// GraphKernel ops
inline const PrimitivePtr kPrimInplaceAssign = std::make_shared<Primitive>("InplaceAssign");
// Only used in lite
inline const PrimitivePtr kPrimLeakyRelu = std::make_shared<Primitive>("LeakyRelu");
inline const PrimitivePtr kPrimConstant = std::make_shared<Primitive>("Constant");
inline const PrimitivePtr kPrimLocalResponseNormalization = std::make_shared<Primitive>("LocalResponseNormalization");
inline const PrimitivePtr kPrimFftReal = std::make_shared<Primitive>("FftReal");
inline const PrimitivePtr kPrimMfcc = std::make_shared<Primitive>("Mfcc");
inline const PrimitivePtr kPrimRfft = std::make_shared<Primitive>("Rfft");
inline const PrimitivePtr kPrimFftImag = std::make_shared<Primitive>("FftImag");
inline const PrimitivePtr kPrimSkipGram = std::make_shared<Primitive>("SkipGram");
inline const PrimitivePtr kPrimConv2DFusion = std::make_shared<Primitive>("Conv2DFusion");
inline const PrimitivePtr kPrimConv2dTransposeFusion = std::make_shared<Primitive>("Conv2dTransposeFusion");
inline const PrimitivePtr kPrimDepthWiseConv2DFusion = std::make_shared<Primitive>("DepthWiseConv2DFusion");
inline const PrimitivePtr kPrimAddFusion = std::make_shared<Primitive>("AddFusion");
inline const PrimitivePtr kPrimScaleFusion = std::make_shared<Primitive>("ScaleFusion");
inline const PrimitivePtr kPrimSubFusion = std::make_shared<Primitive>("SubFusion");
inline const PrimitivePtr kPrimMulFusion = std::make_shared<Primitive>("MulFusion");
inline const PrimitivePtr kPrimSigmoid = std::make_shared<Primitive>("Sigmoid");
inline const PrimitivePtr kPrimClip = std::make_shared<Primitive>("Clip");
inline const PrimitivePtr kPrimHardTanh = std::make_shared<Primitive>("HardTanh");
inline const PrimitivePtr kPrimDepthWiseConv2DTransposeFusion =
std::make_shared<Primitive>("DepthWiseConv2DTransposeFusion");
inline const PrimitivePtr kPrimArgMinFusion = std::make_shared<Primitive>("ArgMinFusion");
inline const PrimitivePtr kPrimArgMaxFusion = std::make_shared<Primitive>("ArgMaxFusion");
inline const PrimitivePtr kPrimSpaceToDepth = std::make_shared<Primitive>("SpaceToDepth");
inline const PrimitivePtr kPrimPadFusion = std::make_shared<Primitive>("PadFusion");
inline const PrimitivePtr kPrimPowFusion = std::make_shared<Primitive>("PowFusion");
inline const PrimitivePtr kPrimResize = std::make_shared<Primitive>("Resize");
inline const PrimitivePtr kPrimConv2dTranspose = std::make_shared<Primitive>("Conv2dTranspose");
inline const PrimitivePtr kPrimArgMinWithValue = std::make_shared<Primitive>("ArgMinWithValue");
inline const PrimitivePtr kPrimIf = std::make_shared<Primitive>("If");
inline const PrimitivePtr kPrimAvgPoolFusion = std::make_shared<Primitive>("AvgPoolFusion");
inline const PrimitivePtr kPrimMaxPoolFusion = std::make_shared<Primitive>("MaxPoolFusion");
inline const PrimitivePtr kPrimActivation = std::make_shared<Primitive>("Activation");
inline const PrimitivePtr kPrimTopKFusion = std::make_shared<Primitive>("TopKFusion");
inline const PrimitivePtr kPrimTileFusion = std::make_shared<Primitive>("TileFusion");
inline const PrimitivePtr kPrimReduceFusion = std::make_shared<Primitive>("ReduceFusion");
inline const PrimitivePtr kPrimLayerNormFusion = std::make_shared<Primitive>("LayerNormFusion");
class DoSignaturePrimitive : public Primitive {
public:
explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function)

View File

@ -1,19 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/abs.h"
namespace mindspore {
REGISTER_PRIMITIVE_C(kNameAbs, Abs);
} // namespace mindspore

View File

@ -1,51 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/apply_momentum.h"
#include "c_ops/op_utils.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
void ApplyMomentum::Init(bool use_nesterov, bool use_locking, float gradient_scale) {
this->set_use_nesterov(use_nesterov);
this->set_use_locking(use_locking);
this->set_gradient_scale(gradient_scale);
}
void ApplyMomentum::set_use_nesterov(bool use_nesterov) { this->AddAttr(kUseNesterov, MakeValue(use_nesterov)); }
void ApplyMomentum::set_use_locking(bool use_locking) { this->AddAttr(kUseLocking, MakeValue(use_locking)); }
void ApplyMomentum::set_gradient_scale(float gradient_scale) {
this->AddAttr(kGradientScale, MakeValue(gradient_scale));
}
bool ApplyMomentum::get_use_nesterov() const {
auto value_ptr = GetAttr(kUseNesterov);
return GetValue<bool>(value_ptr);
}
bool ApplyMomentum::get_use_locking() const {
auto value_ptr = GetAttr(kUseLocking);
return GetValue<bool>(value_ptr);
}
float ApplyMomentum::get_gradient_scale() {
auto value_ptr = GetAttr(kGradientScale);
return GetValue<float>(value_ptr);
}
REGISTER_PRIMITIVE_C(kNameApplyMomentum, ApplyMomentum);
} // namespace mindspore

View File

@ -1,54 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/audio_spectrogram.h"
#include <string>
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include "c_ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
void AudioSpectrogram::set_window_size(const int64_t &window_size) {
this->AddAttr(kWindowSize, MakeValue(window_size));
}
int64_t AudioSpectrogram::get_window_size() const {
auto value_ptr = GetAttr(kWindowSize);
return GetValue<int64_t>(value_ptr);
}
void AudioSpectrogram::set_stride(const int64_t &stride) { this->AddAttr(kStride, MakeValue(stride)); }
int64_t AudioSpectrogram::get_stride() const {
auto value_ptr = GetAttr(kStride);
return GetValue<int64_t>(value_ptr);
}
void AudioSpectrogram::set_mag_square(const bool &mag_square) { this->AddAttr(kMagSquare, MakeValue(mag_square)); }
bool AudioSpectrogram::get_mag_square() const {
auto value_ptr = GetAttr(kMagSquare);
return GetValue<bool>(value_ptr);
}
void AudioSpectrogram::Init(const int64_t &window_size, const int64_t &stride, const bool &mag_square) {
this->set_window_size(window_size);
this->set_stride(stride);
this->set_mag_square(mag_square);
}
REGISTER_PRIMITIVE_C(kNameAudioSpectrogram, AudioSpectrogram);
} // namespace mindspore

View File

@ -1,59 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include "c_ops/batch_norm.h"
#include "abstract/primitive_infer_map.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
void BatchNorm::Init(bool is_training, float epsilon, const Format &format) {
set_is_training(is_training);
set_epsilon(epsilon);
set_format(format);
}
void BatchNorm::set_is_training(bool is_training) { this->AddAttr(kIsTraining, MakeValue(is_training)); }
void BatchNorm::set_epsilon(float epsilon) {
CheckAndConvertUtils::CheckInRange(kEpsilon, epsilon, kIncludeBoth, {0.0, 1.0}, this->name());
this->AddAttr(kEpsilon, MakeValue(epsilon));
}
void BatchNorm::set_format(const Format &format) {
int64_t f = format;
this->AddAttr(kFormat, MakeValue(f));
}
bool BatchNorm::get_is_trainging() {
auto value_ptr = GetAttr(kIsTraining);
return GetValue<bool>(value_ptr);
}
float BatchNorm::get_epsilon() {
auto value_ptr = GetAttr(kEpsilon);
return GetValue<float>(value_ptr);
}
Format BatchNorm::get_format() const {
auto value_ptr = GetAttr(kFormat);
return Format(GetValue<int64_t>(value_ptr));
}
REGISTER_PRIMITIVE_C(kNameBatchNorm, BatchNorm);
} // namespace mindspore

View File

@ -1,21 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/batch_norm_fold.h"
namespace mindspore {
REGISTER_PRIMITIVE_C(kNameBatchNormFold, BatchNormFold);
} // namespace mindspore

View File

@ -1,31 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/binary_cross_entropy_grad.h"
namespace mindspore {
void BinaryCrossEntropyGrad::Init(const std::string &reduction) { set_reduction(reduction); }
void BinaryCrossEntropyGrad::set_reduction(const std::string &reduction) {
CheckAndConvertUtils::CheckString(kReduction, reduction, {"none", "mean", "sum"}, name());
this->AddAttr(kReduction, MakeValue(reduction));
}
std::string BinaryCrossEntropyGrad::get_reduction() const {
auto value_ptr = GetAttr(kReduction);
return GetValue<std::string>(value_ptr);
}
REGISTER_PRIMITIVE_C(kNameBinaryCrossEntropyGrad, BinaryCrossEntropyGrad);
} // namespace mindspore

View File

@ -1,42 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/broadcast.h"
#include "c_ops/op_utils.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
void Broadcast::Init(int64_t root_rank, const std::string &group) {
this->set_root_rank(root_rank);
this->set_group(group);
}
void Broadcast::set_root_rank(int64_t root_rank) { this->AddAttr(kKeepProb, MakeValue(root_rank)); }
void Broadcast::set_group(const std::string &group) {
CheckAndConvertUtils::CheckString(kGroup, group, {"hccl_world_group", "hccl_world_group"}, this->name());
this->AddAttr(kGroup, MakeValue(group));
}
int64_t Broadcast::get_root_rank() {
auto value_ptr = this->GetAttr(kRootRank);
return GetValue<float>(value_ptr);
}
std::string Broadcast::get_group() const {
auto value_ptr = this->GetAttr(kGroup);
return GetValue<std::string>(value_ptr);
}
REGISTER_PRIMITIVE_C(kNameBroadcast, Broadcast);
} // namespace mindspore

View File

@ -1,21 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/ceil.h"
namespace mindspore {
REGISTER_PRIMITIVE_C(kNameCeil, Ceil);
}

View File

@ -1,21 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/cos.h"
namespace mindspore {
REGISTER_PRIMITIVE_C(kNameCos, Cos);
}

View File

@ -1,44 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/custom_predict.h"
#include "c_ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
void CustomPredict::Init(int64_t outputNum, float weight_threshold) {
this->set_outputNum(outputNum);
this->set_weight_threshold(weight_threshold);
}
void CustomPredict::set_outputNum(int64_t outputNum) { this->AddAttr(kOutputNum, MakeValue(outputNum)); }
int64_t CustomPredict::get_outputNum() const {
auto value_ptr = this->GetAttr(kOutputNum);
return GetValue<int64_t>(value_ptr);
}
void CustomPredict::set_weight_threshold(float weight_threshold) {
this->AddAttr(kWeightThreshold, MakeValue(weight_threshold));
}
float CustomPredict::get_weight_threshold() const {
auto value_ptr = this->GetAttr(kWeightThreshold);
return GetValue<float>(value_ptr);
}
REGISTER_PRIMITIVE_C(kNameCustomPredict, CustomPredict);
} // namespace mindspore

View File

@ -1,21 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/div.h"
namespace mindspore {
REGISTER_PRIMITIVE_C(kNameDiv, Div);
} // namespace mindspore

View File

@ -1,20 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/equal.h"
namespace mindspore {
REGISTER_PRIMITIVE_C(kNameEqual, Equal);
} // namespace mindspore

View File

@ -1,20 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/exp.h"
namespace mindspore {
REGISTER_PRIMITIVE_C(kNameExp, Exp);
} // namespace mindspore

View File

@ -1,44 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/fake_quant_with_min_max_vars.h"
#include "c_ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
void FakeQuantWithMinMaxVars::Init(const bool &narrow_range, int64_t num_bits) {
this->set_narrow_range(narrow_range);
this->set_num_bits(num_bits);
}
void FakeQuantWithMinMaxVars::set_narrow_range(const bool &narrow_range) {
this->AddAttr(kNarrowRange, MakeValue(narrow_range));
}
bool FakeQuantWithMinMaxVars::get_narrow_range() const {
auto value_ptr = this->GetAttr(kNarrowRange);
return GetValue<bool>(value_ptr);
}
void FakeQuantWithMinMaxVars::set_num_bits(int64_t num_bits) { this->AddAttr(kNumBits, MakeValue(num_bits)); }
int64_t FakeQuantWithMinMaxVars::get_num_bits() const {
auto value_ptr = this->GetAttr(kNumBits);
return GetValue<int64_t>(value_ptr);
}
REGISTER_PRIMITIVE_C(kNameFakeQuantWithMinMaxVars, FakeQuantWithMinMaxVars);
} // namespace mindspore

View File

@ -1,22 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/fft_imag.h"
#include <memory>
namespace mindspore {
REGISTER_PRIMITIVE_C(kNameFftImag, FftImag);
}

View File

@ -1,21 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/flatten_grad.h"
namespace mindspore {
REGISTER_PRIMITIVE_C(kNameFlattenGrad, FlattenGrad);
}

View File

@ -1,22 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/hashtable_lookup.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
REGISTER_PRIMITIVE_C(kNameHashtableLookup, HashtableLookup);
} // namespace mindspore

View File

@ -1,21 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/less.h"
namespace mindspore {
REGISTER_PRIMITIVE_C(kNameLess, Less);
}

View File

@ -1,20 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/less_equal.h"
namespace mindspore {
REGISTER_PRIMITIVE_C(kNameLessEqual, LessEqual);
} // namespace mindspore

View File

@ -1,65 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/local_response_normalization.h"
#include <string>
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include "c_ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
void LocalResponseNormalization::set_depth_radius(const int64_t &depth_radius) {
this->AddAttr(kDepthRadius, MakeValue(depth_radius));
}
int64_t LocalResponseNormalization::get_depth_radius() const {
auto value_ptr = GetAttr(kDepthRadius);
return GetValue<int64_t>(value_ptr);
}
void LocalResponseNormalization::set_bias(const float &bias) { this->AddAttr(kBias, MakeValue(bias)); }
float LocalResponseNormalization::get_bias() const {
auto value_ptr = GetAttr(kBias);
return GetValue<float>(value_ptr);
}
void LocalResponseNormalization::set_alpha(const float &alpha) { this->AddAttr(kAlpha, MakeValue(alpha)); }
float LocalResponseNormalization::get_alpha() const {
auto value_ptr = GetAttr(kAlpha);
return GetValue<float>(value_ptr);
}
void LocalResponseNormalization::set_beta(const float &beta) { this->AddAttr(kBeta, MakeValue(beta)); }
float LocalResponseNormalization::get_beta() const {
auto value_ptr = GetAttr(kBeta);
return GetValue<float>(value_ptr);
}
void LocalResponseNormalization::Init(const int64_t &depth_radius, const float &bias, const float &alpha,
const float &beta) {
this->set_depth_radius(depth_radius);
this->set_bias(bias);
this->set_alpha(alpha);
this->set_beta(beta);
}
REGISTER_PRIMITIVE_C(kNameLocalResponseNormalization, LocalResponseNormalization);
} // namespace mindspore

View File

@ -1,21 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/log.h"
namespace mindspore {
REGISTER_PRIMITIVE_C(kNameLog, Log);
}

View File

@ -1,20 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/logical_not.h"
namespace mindspore {
REGISTER_PRIMITIVE_C(kNameLogicalNot, LogicalNot);
} // namespace mindspore

View File

@ -1,20 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/logical_or.h"
namespace mindspore {
REGISTER_PRIMITIVE_C(kNameLogicalOr, LogicalOr);
} // namespace mindspore

View File

@ -1,82 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/lstm.h"
namespace mindspore {
void LSTM::set_input_size(const int64_t &input_size) {
CheckAndConvertUtils::CheckInteger(kInput_size, input_size, kGreaterThan, 0, this->name());
AddAttr(kInput_size, MakeValue(input_size));
}
int64_t LSTM::get_input_size() const {
auto value_ptr = this->GetAttr(kInput_size);
return GetValue<int64_t>(value_ptr);
}
void LSTM::set_hidden_size(const int64_t &hidden_size) {
CheckAndConvertUtils::CheckInteger(kHidden_size, hidden_size, kGreaterThan, 0, this->name());
AddAttr(kHidden_size, MakeValue(hidden_size));
}
int64_t LSTM::get_hidden_size() const {
auto value_ptr = this->GetAttr(kHidden_size);
return GetValue<int64_t>(value_ptr);
}
void LSTM::set_num_layers(const int64_t &num_layers) {
CheckAndConvertUtils::CheckInteger(kNum_layers, num_layers, kGreaterThan, 0, this->name());
AddAttr(kNum_layers, MakeValue(kNum_layers));
}
int64_t LSTM::get_num_layers() const {
auto value_ptr = this->GetAttr(kNum_layers);
return GetValue<int64_t>(value_ptr);
}
void LSTM::set_has_bias(const bool &has_bias) { AddAttr(kHasBias, MakeValue(has_bias)); }
bool LSTM::get_has_bias() const {
auto value_ptr = this->GetAttr(kHasBias);
return GetValue<bool>(value_ptr);
}
void LSTM::set_dropout(const float &dropout) {
CheckAndConvertUtils::CheckInRange(kDropout, dropout, kIncludeBoth, {0, 1}, this->name());
AddAttr(kDropout, MakeValue(dropout));
}
float LSTM::get_dropout() const {
auto value_ptr = this->GetAttr(kDropout);
return GetValue<float>(value_ptr);
}
void LSTM::set_bidirectional(const bool &bidirectional) { AddAttr(kBidirectional, MakeValue(bidirectional)); }
bool LSTM::get_bidirectional() const {
auto value_ptr = this->GetAttr(kBidirectional);
return GetValue<bool>(value_ptr);
}
void LSTM::set_num_directions(const int64_t &num_directions) { AddAttr(kNumDirections, MakeValue(num_directions)); }
int64_t LSTM::get_num_directions() const {
auto value_ptr = this->GetAttr(kNumDirections);
return GetValue<int64_t>(value_ptr);
}
void LSTM::Init(const int64_t &input_size, const int64_t &hidden_size, const int64_t &num_layers, const bool &has_bias,
const float &dropout, const bool &bidirectional) {
this->set_input_size(input_size);
this->set_hidden_size(hidden_size);
this->set_num_layers(num_layers);
this->set_has_bias(has_bias);
this->set_dropout(dropout);
this->set_bidirectional(bidirectional);
if (bidirectional) {
this->set_num_directions(2);
} else {
this->set_num_directions(1);
}
}
REGISTER_PRIMITIVE_C(kNameLSTM, LSTM);
} // namespace mindspore

View File

@ -25,7 +25,7 @@
#include <utility>
#include "ir/tensor.h"
#include "ir/param_info.h"
#include "c_ops/primitive_c.h"
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/log_adapter.h"
#include "utils/shape_utils.h"
@ -676,7 +676,7 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
const std::string &node_type = node_proto.op_type();
std::shared_ptr<Primitive> prim;
auto op_primc_fns = OpPrimCRegister::GetInstance().GetPrimCMap();
auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
if (op_primc_fns.find(node_type) != op_primc_fns.end()) {
prim = op_primc_fns[node_type]();
} else {

62
mindspore/core/ops/abs.cc Normal file
View File

@ -0,0 +1,62 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/abs.h"
#include <string>
#include <algorithm>
#include <memory>
#include <set>
#include <map>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto abs_prim = primitive->cast<PrimAbsPtr>();
MS_EXCEPTION_IF_NULL(abs_prim);
auto prim_name = abs_prim->name();
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name);
return std::make_shared<abstract::Shape>(in_shape);
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) {
MS_LOG(EXCEPTION) << "nullptr";
}
std::map<std::string, TypePtr> types;
types.emplace("input_x", input_args[0]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
}
} // namespace
AbstractBasePtr AbsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(Abs, prim::kPrimAbs, AbsInfer);
REGISTER_PRIMITIVE_C(kNameAbs, Abs);
} // namespace ops
} // namespace mindspore

View File

@ -14,13 +14,17 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_ABS_H_
#define MINDSPORE_CORE_C_OPS_ABS_H_
#include "c_ops/primitive_c.h"
#ifndef MINDSPORE_CORE_OPS_ABS_H_
#define MINDSPORE_CORE_OPS_ABS_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAbs = "Abs";
class Abs : public PrimitiveC {
public:
@ -29,6 +33,10 @@ class Abs : public PrimitiveC {
MS_DECLARE_PARENT(Abs, PrimitiveC);
void Init() {}
};
AbstractBasePtr AbsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimAbsPtr = std::shared_ptr<Abs>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_ABS_H_
#endif // MINDSPORE_CORE_OPS_ABS_H_

View File

@ -0,0 +1,88 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/adam.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
namespace {
abstract::AbstractBasePtr AdamInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto Adam_prim = primitive->cast<PrimAdamPtr>();
MS_EXCEPTION_IF_NULL(Adam_prim);
auto prim_name = Adam_prim->name();
// infer shape
auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShape("var_shape", input_args[0]->GetShapeTrack(), prim_name);
auto m_shape = CheckAndConvertUtils::ConvertShapePtrToShape("m_shape", input_args[1]->GetShapeTrack(), prim_name);
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[2]->GetShapeTrack(), prim_name);
auto grad_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("grad_shape", input_args[9]->GetShapeTrack(), prim_name);
CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "m_shape", m_shape, prim_name);
CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "v_shape", v_shape, prim_name);
CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "grad_shape", grad_shape, prim_name);
// infer type
auto var_type = input_args[0]->BuildType();
auto m_type = input_args[1]->BuildType();
auto v_type = input_args[2]->BuildType();
auto grad_type = input_args[9]->BuildType();
CheckAndConvertUtils::CheckTensorTypeValid("var_type", var_type, common_valid_types, prim_name);
CheckAndConvertUtils::CheckTensorTypeValid("m_type", m_type, common_valid_types, prim_name);
CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_type, common_valid_types, prim_name);
CheckAndConvertUtils::CheckTensorTypeValid("grad_type", grad_type, common_valid_types, prim_name);
auto infer_var_type = var_type->cast<TensorTypePtr>()->element();
auto infer_m_type = m_type->cast<TensorTypePtr>()->element();
auto infer_v_type = v_type->cast<TensorTypePtr>()->element();
// auto infer_grad_type = grad_type->cast<TensorTypePtr>()->element();
auto output0 = std::make_shared<abstract::AbstractTensor>(infer_var_type, var_shape);
auto output1 = std::make_shared<abstract::AbstractTensor>(infer_m_type, m_shape);
auto output2 = std::make_shared<abstract::AbstractTensor>(infer_v_type, v_shape);
AbstractBasePtrList output = {output0, output1, output2};
return std::make_shared<abstract::AbstractTuple>(output);
}
} // namespace
void Adam::Init(const bool use_locking, const bool use_nesterov) {
this->set_use_locking(use_locking);
this->set_use_nesterov(use_nesterov);
}
void Adam::set_use_locking(const bool use_locking) { this->AddAttr(kUseLocking, MakeValue(use_locking)); }
void Adam::set_use_nesterov(const bool use_nesterov) { this->AddAttr(kUseNesterov, MakeValue(use_nesterov)); }
bool Adam::get_use_locking() const {
auto value_ptr = GetAttr(kUseLocking);
return GetValue<bool>(value_ptr);
}
bool Adam::get_use_nesterov() const {
auto value_ptr = GetAttr(kUseNesterov);
return GetValue<bool>(value_ptr);
}
AbstractBasePtr AdamInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(AdamInfer(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(Adam, prim::kPrimAdam, AdamInfer);
REGISTER_PRIMITIVE_C(kNameAdam, Adam);
} // namespace ops
} // namespace mindspore

View File

@ -14,29 +14,34 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_ADAM_H_
#define MINDSPORE_CORE_C_OPS_ADAM_H_
#ifndef MINDSPORE_CORE_OPS_ADAM_H_
#define MINDSPORE_CORE_OPS_ADAM_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "c_ops/primitive_c.h"
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAdam = "Adam";
class Adam : public PrimitiveC {
public:
Adam() : PrimitiveC(kNameAdam) {}
~Adam() = default;
MS_DECLARE_PARENT(Adam, PrimitiveC);
void Init(const bool &use_locking = false, const bool &use_nesteroy = false);
void set_use_locking(const bool &use_locking);
void set_use_nesteroy(const bool &use_nesteroy);
void Init(const bool use_locking = false, const bool use_nesterov = false);
void set_use_locking(const bool use_locking);
void set_use_nesterov(const bool use_nesterov);
bool get_use_locking() const;
bool get_use_nesteroy() const;
bool get_use_nesterov() const;
};
AbstractBasePtr AdamInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimAdamPtr = std::shared_ptr<Adam>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_ADAM_H_
#endif // MINDSPORE_CORE_OPS_ADAM_H_

56
mindspore/core/ops/add.cc Normal file
View File

@ -0,0 +1,56 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/add.h"
#include <algorithm>
#include <memory>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto add_prim = primitive->cast<PrimAddPtr>();
MS_EXCEPTION_IF_NULL(add_prim);
auto prim_name = add_prim->name();
return BroadCastInferShape(prim_name, input_args);
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
MS_LOG(EXCEPTION) << "nullptr";
}
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
types.emplace("y", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
}
} // namespace
AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(Add, prim::kPrimAdd, AddInfer);
REGISTER_PRIMITIVE_C(kNameAdd, Add);
} // namespace ops
} // namespace mindspore

View File

@ -14,21 +14,23 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_ADD_H_
#define MINDSPORE_CORE_C_OPS_ADD_H_
#ifndef MINDSPORE_CORE_OPS_ADD_H_
#define MINDSPORE_CORE_OPS_ADD_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "c_ops/primitive_c.h"
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAdd = "Add";
class Add : public PrimitiveC {
public:
Add() : PrimitiveC(kNameAdd) { InitIOName({"x", "y"}, {"output"}); }
explicit Add(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x", "y"}, {"output"}); }
~Add() = default;
MS_DECLARE_PARENT(Add, PrimitiveC);
void Init() {}
@ -37,6 +39,7 @@ class Add : public PrimitiveC {
AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimAddPtr = std::shared_ptr<Add>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_ADD_H_
#endif // MINDSPORE_CORE_OPS_ADD_H_

View File

@ -14,8 +14,10 @@
* limitations under the License.
*/
#include "c_ops/add_fold.h"
#include "ops/add_fold.h"
namespace mindspore {
namespace ops {
REGISTER_PRIMITIVE_C(kNameAddFold, AddFold);
} // namespace ops
} // namespace mindspore

View File

@ -14,17 +14,18 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_ADDFOLD_H_
#define MINDSPORE_CORE_C_OPS_ADDFOLD_H_
#ifndef MINDSPORE_CORE_OPS_ADD_FOLD_H_
#define MINDSPORE_CORE_OPS_ADD_FOLD_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "c_ops/primitive_c.h"
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAddFold = "AddFold";
class AddFold : public PrimitiveC {
public:
@ -33,6 +34,7 @@ class AddFold : public PrimitiveC {
MS_DECLARE_PARENT(AddFold, PrimitiveC);
void Init() {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_ADDFOLD_H_
#endif // MINDSPORE_CORE_OPS_ADD_FOLD_H_

108
mindspore/core/ops/adder.cc Normal file
View File

@ -0,0 +1,108 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/adder.h"
#include "ops/op_utils.h"
namespace mindspore {
namespace ops {
void Adder::Init(const int64_t in_channel, const int64_t out_channel, const std::vector<int64_t> &kernel_size,
const PadMode &pad_mode, const std::vector<int64_t> &stride, const std::vector<int64_t> &pad_list,
const std::vector<int64_t> &dilation, const int64_t group, const Format &format) {
set_in_channel(in_channel);
set_out_channel(out_channel);
set_kernel_size(kernel_size);
set_pad_mode(pad_mode);
set_stride(stride);
set_pad_list(pad_list);
set_dilation(dilation);
set_group(group);
set_format(format);
}
void Adder::set_in_channel(const int64_t in_channel) { this->AddAttr(kInChannel, MakeValue(in_channel)); }
int64_t Adder::get_in_channel() const {
auto value_ptr = GetAttr(kInChannel);
return GetValue<int64_t>(value_ptr);
}
void Adder::set_out_channel(const int64_t out_channel) { this->AddAttr(kOutChannel, MakeValue(out_channel)); }
int64_t Adder::get_out_channel() const {
auto value_ptr = GetAttr(kOutChannel);
return GetValue<int64_t>(value_ptr);
}
void Adder::set_kernel_size(const std::vector<int64_t> &kernel_size) {
this->AddAttr(kKernelSize, MakeValue(kernel_size));
}
std::vector<int64_t> Adder::get_kernel_size() const {
auto value_ptr = GetAttr(kKernelSize);
return GetValue<std::vector<int64_t>>(value_ptr);
}
void Adder::set_pad_mode(const PadMode &pad_mode) {
int64_t swi = pad_mode;
this->AddAttr(kPadMode, MakeValue(swi));
}
PadMode Adder::get_pad_mode() const {
auto value_ptr = GetAttr(kPadMode);
return PadMode(GetValue<int64_t>(value_ptr));
}
void Adder::set_stride(const std::vector<int64_t> &stride) { this->AddAttr(kStride, MakeValue(stride)); }
std::vector<int64_t> Adder::get_stride() const {
auto value_ptr = GetAttr(kStride);
return GetValue<std::vector<int64_t>>(value_ptr);
}
void Adder::set_pad_list(const std::vector<int64_t> &pad_list) { this->AddAttr(kPadList, MakeValue(pad_list)); }
std::vector<int64_t> Adder::get_pad_list() const {
auto value_ptr = GetAttr(kPadList);
return GetValue<std::vector<int64_t>>(value_ptr);
}
void Adder::set_dilation(const std::vector<int64_t> &dilation) { this->AddAttr(kDilation, MakeValue(dilation)); }
std::vector<int64_t> Adder::get_dilation() const {
auto value_ptr = GetAttr(kDilation);
return GetValue<std::vector<int64_t>>(value_ptr);
}
void Adder::set_group(const int64_t group) { this->AddAttr(kGroup, MakeValue(group)); }
int64_t Adder::get_group() const {
auto value_ptr = GetAttr(kGroup);
return GetValue<int64_t>(value_ptr);
}
void Adder::set_format(const Format &format) {
int64_t swi = format;
this->AddAttr(kFormat, MakeValue(swi));
}
Format Adder::get_format() const {
auto value_ptr = GetAttr(kFormat);
return Format(GetValue<int64_t>(value_ptr));
}
REGISTER_PRIMITIVE_C(kNameAdder, Adder);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,62 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_ADDER_H_
#define MINDSPORE_CORE_OPS_ADDER_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAdder = "Adder";
class Adder : public PrimitiveC {
public:
explicit Adder(const std::string &k_name = kNameAdder) : PrimitiveC(k_name) {}
~Adder() = default;
MS_DECLARE_PARENT(Adder, PrimitiveC);
void Init(const int64_t in_channel, const int64_t out_channel, const std::vector<int64_t> &kernel_size,
const PadMode &pad_mode, const std::vector<int64_t> &stride, const std::vector<int64_t> &pad_list,
const std::vector<int64_t> &dilation, const int64_t group, const Format &format);
void set_in_channel(const int64_t in_channel);
void set_out_channel(const int64_t out_channel);
void set_kernel_size(const std::vector<int64_t> &kernel_size);
void set_pad_mode(const PadMode &pad_mode);
void set_stride(const std::vector<int64_t> &stride);
void set_pad_list(const std::vector<int64_t> &pad_list);
void set_dilation(const std::vector<int64_t> &dilation);
void set_group(const int64_t group);
void set_format(const Format &format);
int64_t get_in_channel() const;
int64_t get_out_channel() const;
std::vector<int64_t> get_kernel_size() const;
PadMode get_pad_mode() const;
std::vector<int64_t> get_stride() const;
std::vector<int64_t> get_pad_list() const;
std::vector<int64_t> get_dilation() const;
int64_t get_group() const;
Format get_format() const;
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_ADDER_H_

View File

@ -0,0 +1,70 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <set>
#include <string>
#include <vector>
#include <map>
#include <memory>
#include "ops/addn.h"
#include "ops/op_utils.h"
namespace mindspore {
namespace ops {
AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto input_tuple = input_args[0]->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(input_tuple);
auto elements = input_tuple->elements();
CheckAndConvertUtils::CheckInteger("concat element num", elements.size(), kGreaterEqual, 1, prim_name);
auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(element0);
auto element0_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("element0 shape", element0->BuildShape(), prim_name);
std::map<std::string, TypePtr> types;
types.emplace("element0", element0->BuildType());
for (size_t i = 1; i < elements.size(); ++i) {
std::string elementi = "element" + std::to_string(i);
auto elementi_shape =
CheckAndConvertUtils::ConvertShapePtrToShape(elementi + " shape", elements[i]->BuildShape(), prim_name);
CheckAndConvertUtils::CheckInteger(elementi + " shape rank", elementi_shape.size(), kEqual, element0_shape.size(),
prim_name);
for (size_t j = 0; j < element0_shape.size(); ++j) {
if (elementi_shape[j] != element0_shape[j]) {
MS_LOG(EXCEPTION) << "element " << i << " shape in input can not concat with first element.";
}
}
types.emplace(elementi, elements[i]->BuildType());
}
std::set<TypeId> valid_types = common_valid_types;
valid_types.insert(kNumberTypeBool);
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
return std::make_shared<abstract::AbstractTensor>(TypeIdToType(infer_type),
std::make_shared<abstract::Shape>(element0_shape));
}
REGISTER_PRIMITIVE_EVAL_IMPL(AddN, prim::kPrimAddN, AddNInfer);
REGISTER_PRIMITIVE_C(kNameAddN, AddN);
} // namespace ops
} // namespace mindspore

View File

@ -14,13 +14,16 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_ADDN_H_
#define MINDSPORE_CORE_C_OPS_ADDN_H_
#include "c_ops/primitive_c.h"
#ifndef MINDSPORE_CORE_OPS_ADDN_H_
#define MINDSPORE_CORE_OPS_ADDN_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAddN = "AddN";
class AddN : public PrimitiveC {
public:
@ -29,6 +32,10 @@ class AddN : public PrimitiveC {
MS_DECLARE_PARENT(AddN, PrimitiveC);
void Init() {}
};
AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimAddNPtr = std::shared_ptr<AddN>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_ADDN_H_
#endif // MINDSPORE_CORE_OPS_ADDN_H_

View File

@ -14,19 +14,20 @@
* limitations under the License.
*/
#include "c_ops/concat.h"
#include "c_ops/op_utils.h"
#include "ops/all.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
void Concat::Init(int64_t axis) { this->set_axis(axis); }
int64_t Concat::get_axis() const {
auto value_ptr = this->GetAttr(kAxis);
namespace mindspore {
namespace ops {
void All::Init(const int64_t keep_dims) { this->set_keep_dims(keep_dims); }
void All::set_keep_dims(const int64_t keep_dims) { this->AddAttr(kKeepDims, MakeValue(keep_dims)); }
int64_t All::get_keep_dims() const {
auto value_ptr = GetAttr(kKeepDims);
return GetValue<int64_t>(value_ptr);
}
void Concat::set_axis(int64_t axis) {
this->AddAttr(kAxis, MakeValue(CheckAndConvertUtils::CheckInteger(kAxis, axis, kGreaterEqual, 0, this->name())));
}
REGISTER_PRIMITIVE_C(kNameConcat, Concat);
REGISTER_PRIMITIVE_C(kNameAll, All);
} // namespace ops
} // namespace mindspore

38
mindspore/core/ops/all.h Normal file
View File

@ -0,0 +1,38 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_ALL_H_
#define MINDSPORE_CORE_OPS_ALL_H_
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAll = "All";
class All : public PrimitiveC {
public:
All() : PrimitiveC(kNameAll) {}
~All() = default;
MS_DECLARE_PARENT(All, PrimitiveC);
void Init(const int64_t keep_dims);
void set_keep_dims(const int64_t keep_dims);
int64_t get_keep_dims() const;
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_ALL_H_

View File

@ -0,0 +1,89 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <memory>
#include <set>
#include <map>
#include <string>
#include "ops/apply_momentum.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
void ApplyMomentum::Init(const bool use_nesterov, const bool use_locking, const float gradient_scale) {
this->set_use_nesterov(use_nesterov);
this->set_use_locking(use_locking);
this->set_gradient_scale(gradient_scale);
}
void ApplyMomentum::set_use_nesterov(const bool use_nesterov) { this->AddAttr(kUseNesterov, MakeValue(use_nesterov)); }
void ApplyMomentum::set_use_locking(const bool use_locking) { this->AddAttr(kUseLocking, MakeValue(use_locking)); }
void ApplyMomentum::set_gradient_scale(const float gradient_scale) {
this->AddAttr(kGradientScale, MakeValue(gradient_scale));
}
bool ApplyMomentum::get_use_nesterov() const {
auto value_ptr = GetAttr(kUseNesterov);
return GetValue<bool>(value_ptr);
}
bool ApplyMomentum::get_use_locking() const {
auto value_ptr = GetAttr(kUseLocking);
return GetValue<bool>(value_ptr);
}
float ApplyMomentum::get_gradient_scale() const {
auto value_ptr = GetAttr(kGradientScale);
return GetValue<float>(value_ptr);
}
AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto momentum_prim = primitive->cast<PrimApplyMomentumPtr>();
MS_EXCEPTION_IF_NULL(momentum_prim);
auto prim_name = momentum_prim->name();
CheckAndConvertUtils::CheckInteger("apply_momentum_infer", input_args.size(), kEqual, 5, prim_name);
// Infer shape
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[0]->BuildShape(), prim_name);
// Infer type
auto v_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
auto a_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element();
auto l_type = input_args[2]->BuildType()->cast<TensorTypePtr>()->element();
auto g_type = input_args[3]->BuildType()->cast<TensorTypePtr>()->element();
auto m_type = input_args[4]->BuildType()->cast<TensorTypePtr>()->element();
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_type, valid_types, prim_name);
CheckAndConvertUtils::CheckTensorTypeValid("a_type", a_type, valid_types, prim_name);
const std::set<TypePtr> valid_types_ptr = {TypeIdToType(kNumberTypeFloat16), TypeIdToType(kNumberTypeFloat32),
TypeIdToType(kNumberTypeFloat64)};
std::map<std::string, TypePtr> args;
args.insert({"l_type", l_type});
args.insert({"g_type", g_type});
args.insert({"m_type", m_type});
CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, valid_types_ptr, prim_name);
return std::make_shared<abstract::AbstractTensor>(g_type, v_shape);
}
REGISTER_PRIMITIVE_EVAL_IMPL(ApplyMomentum, prim::kPrimApplyMomentum, ApplyMomentumInfer);
REGISTER_PRIMITIVE_C(kNameApplyMomentum, ApplyMomentum);
} // namespace ops
} // namespace mindspore

View File

@ -14,13 +14,17 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_APPLYMOMENTUM_H_
#define MINDSPORE_CORE_C_OPS_APPLYMOMENTUM_H_
#include "c_ops/primitive_c.h"
#ifndef MINDSPORE_CORE_OPS_APPLY_MOMENTUM_H_
#define MINDSPORE_CORE_OPS_APPLY_MOMENTUM_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameApplyMomentum = "ApplyMomentum";
class ApplyMomentum : public PrimitiveC {
public:
@ -29,14 +33,18 @@ class ApplyMomentum : public PrimitiveC {
}
~ApplyMomentum() = default;
MS_DECLARE_PARENT(ApplyMomentum, PrimitiveC);
void Init(bool use_nesterov, bool use_locking, float gradient_scale);
void set_use_nesterov(bool use_nesterov);
void set_use_locking(bool use_locking);
void set_gradient_scale(float gradient_scale);
void Init(const bool use_nesterov = false, const bool use_locking = false, const float gradient_scale = 1.0);
void set_use_nesterov(const bool use_nesterov);
void set_use_locking(const bool use_locking);
void set_gradient_scale(const float gradient_scale);
bool get_use_nesterov() const;
bool get_use_locking() const;
float get_gradient_scale();
float get_gradient_scale() const;
};
AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimApplyMomentumPtr = std::shared_ptr<ApplyMomentum>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_APPLYMOMENTUM_H_
#endif // MINDSPORE_CORE_OPS_APPLY_MOMENTUM_H_

View File

@ -0,0 +1,73 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/arg_max.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto prim = primitive->cast<PrimArgMaxPtr>();
MS_EXCEPTION_IF_NULL(prim);
auto axis = prim->get_axis();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_rank = SizeToLong(x_shape.size());
CheckAndConvertUtils::CheckInRange<int64_t>("argmax axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name);
axis = axis < 0 ? axis + x_rank : axis;
std::vector<int64_t> out_shape;
for (size_t i = 0; i < x_shape.size(); ++i) {
if (SizeToLong(i) != axis) {
out_shape.emplace_back(x_shape[i]);
}
}
return std::make_shared<abstract::Shape>(out_shape);
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
return kInt32;
}
} // namespace
void ArgMax::Init(const int64_t axis, const TypeId output_type) {
set_axis(axis);
set_output_type(output_type);
}
void ArgMax::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis)); }
void ArgMax::set_output_type(const TypeId output_type) { this->AddAttr(kOutputType, TypeIdToType(output_type)); }
int64_t ArgMax::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); }
TypeId ArgMax::get_output_type() const {
auto type_ptr = GetAttr(kOutputType)->cast<TensorTypePtr>()->element();
return type_ptr->type_id();
}
AbstractBasePtr ArgMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(ArgMax, prim::kPrimArgMax, ArgMaxInfer);
REGISTER_PRIMITIVE_C(kNameArgMax, ArgMax);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,50 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_ARG_MAX_H_
#define MINDSPORE_CORE_OPS_ARG_MAX_H_
#include <string>
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameArgMax = "Argmax";
class ArgMax : public PrimitiveC {
public:
ArgMax() : PrimitiveC(kNameArgMax) { InitIOName({"x"}, {"output"}); }
explicit ArgMax(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x"}, {"output"}); }
~ArgMax() = default;
MS_DECLARE_PARENT(ArgMax, PrimitiveC);
void Init(const int64_t axis = -1, const TypeId output_type = kNumberTypeInt32);
void set_axis(const int64_t axis);
void set_output_type(const TypeId output_type);
int64_t get_axis() const;
TypeId get_output_type() const;
};
AbstractBasePtr ArgMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimArgMaxPtr = std::shared_ptr<ArgMax>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_ARG_MAX_H_

View File

@ -0,0 +1,73 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <set>
#include "ops/arg_min.h"
namespace mindspore {
namespace ops {
void ArgMin::Init(const int64_t axis, const TypeId output_type) {
set_axis(axis);
set_output_type(output_type);
}
void ArgMin::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis)); }
void ArgMin::set_output_type(const TypeId output_type) { this->AddAttr(kOutputType, TypeIdToType(output_type)); }
int64_t ArgMin::get_axis() const {
auto value_ptr = GetAttr(kAxis);
return GetValue<int64_t>(value_ptr);
}
TypeId ArgMin::get_output_type() const {
auto type_ptr = GetAttr(kOutputType)->cast<TensorTypePtr>()->element();
return type_ptr->type_id();
}
AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto argmin_prim = primitive->cast<PrimArgMin>();
MS_EXCEPTION_IF_NULL(argmin_prim);
auto prim_name = argmin_prim->name();
CheckAndConvertUtils::CheckInteger("arg_min_infer", input_args.size(), kEqual, 1, prim_name);
// Infer shape
auto axis = argmin_prim->get_axis();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_rank = SizeToLong(x_shape.size());
CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name);
if (axis < 0) {
axis += x_rank;
}
std::vector<int64_t> out_shape;
for (int64_t i = 0; i < x_rank; i++) {
if (i != axis) {
out_shape.push_back(x_shape[i]);
}
}
// Infer type
auto x_dtype = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
std::set<TypePtr> template_types = {TypeIdToType(kObjectTypeTensorType)};
CheckAndConvertUtils::CheckSubClass("x_dtype", x_dtype, template_types, prim_name);
return std::make_shared<abstract::AbstractTensor>(x_dtype, std::make_shared<abstract::Shape>(out_shape));
}
REGISTER_PRIMITIVE_EVAL_IMPL(ArgMin, prim::kPrimArgMin, ArgMinInfer);
REGISTER_PRIMITIVE_C(kNameArgMin, ArgMin);
} // namespace ops
} // namespace mindspore

View File

@ -14,32 +14,37 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_ARGMIN_H_
#define MINDSPORE_CORE_C_OPS_ARGMIN_H_
#ifndef MINDSPORE_CORE_OPS_ARG_MIN_H_
#define MINDSPORE_CORE_OPS_ARG_MIN_H_
#include <string>
#include <vector>
#include <memory>
#include "c_ops/primitive_c.h"
#include "c_ops/op_utils.h"
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameArgMin = "ArgMin";
class ArgMin : public PrimitiveC {
public:
ArgMin() : PrimitiveC(kNameArgMin) { InitIOName({"x"}, {"output"}); }
explicit ArgMin(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x"}, {"output"}); }
~ArgMin() = default;
MS_DECLARE_PARENT(ArgMin, PrimitiveC);
void Init(bool keep_dims, int64_t axis = -1);
void set_axis(int64_t axis);
void set_keep_dims(bool keep_dims);
int64_t get_axis();
bool get_keep_dims();
void Init(const int64_t axis = -1, const TypeId output_type = kNumberTypeInt32);
void set_axis(const int64_t axis);
void set_output_type(const TypeId output_type);
int64_t get_axis() const;
TypeId get_output_type() const;
};
AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimArgMin = std::shared_ptr<ArgMin>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_ARGMIN_H_
#endif // MINDSPORE_CORE_OPS_ARG_MIN_H_

View File

@ -0,0 +1,52 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <set>
#include <vector>
#include <memory>
#include "ops/asin.h"
namespace mindspore {
namespace ops {
AbstractBasePtr AsinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto asin_prim = primitive->cast<PrimAsinPtr>();
MS_EXCEPTION_IF_NULL(asin_prim);
auto prim_name = asin_prim->name();
CheckAndConvertUtils::CheckInteger("Asin_infer", input_args.size(), kEqual, 1, prim_name);
// Infer Shape
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto infer_shape = std::make_shared<abstract::Shape>(x_shape);
// Infer Type
auto dtype = input_args[0]->BuildType();
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeInt32};
CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim_name);
auto tensor_type = dtype->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto element = tensor_type->element();
MS_EXCEPTION_IF_NULL(element);
auto infer_type = std::make_shared<TensorType>(TypeIdToType(element->type_id()));
return std::make_shared<abstract::AbstractTensor>(infer_type, infer_shape->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(Asin, prim::kPrimAsin, AsinInfer);
REGISTER_PRIMITIVE_C(kNameAsin, Asin);
} // namespace ops
} // namespace mindspore

42
mindspore/core/ops/asin.h Normal file
View File

@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_ASIN_H_
#define MINDSPORE_CORE_OPS_ASIN_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAsin = "Asin";
class Asin : public PrimitiveC {
public:
Asin() : PrimitiveC(kNameAsin) {}
~Asin() = default;
MS_DECLARE_PARENT(Asin, PrimitiveC);
void Init() {}
};
AbstractBasePtr ASinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimAsinPtr = std::shared_ptr<Asin>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_ASIN_H_

View File

@ -0,0 +1,78 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <map>
#include <string>
#include <set>
#include <vector>
#include <memory>
#include "ops/assert.h"
#include "ops/op_utils.h"
namespace mindspore {
namespace ops {
void Assert::Init(const int64_t summarize) { set_summarize(summarize); }
void Assert::set_summarize(const int64_t summarize) { this->AddAttr(kSummarize, MakeValue(summarize)); }
int64_t Assert::get_summarize() const {
auto value_ptr = GetAttr(kSummarize);
return GetValue<int64_t>(value_ptr);
}
AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto Assert_prim = primitive->cast<PrimAssertPtr>();
MS_EXCEPTION_IF_NULL(Assert_prim);
auto op_name = Assert_prim->name();
TypePtr condition;
if (!(input_args[0]->BuildType()->type_id() == kObjectTypeTensorType)) {
auto condition_value = GetValue<std::vector<bool>>(input_args[0]->BuildValue());
CheckAndConvertUtils::CheckInteger("condition's rank", condition_value.size(), kLessEqual, 1, op_name);
if (condition_value.size() == 1) {
CheckAndConvertUtils::CheckInteger("condition[0]", condition_value[0], kEqual, 1, op_name);
}
condition = TypeIdToType(kNumberTypeBool);
} else {
auto condition_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
CheckAndConvertUtils::CheckInteger("condition's rank", condition_shape[0], kLessEqual, 1, op_name);
if (condition_shape[0] == 1) {
auto condition_value = reinterpret_cast<bool *>(input_args[0]->BuildValue()->cast<tensor::TensorPtr>()->data_c());
MS_EXCEPTION_IF_NULL(condition_value);
// auto condition_value = GetValue<bool>(input_args[0]->BuildValue());
CheckAndConvertUtils::CheckInteger("condition[0]", *condition_value, kEqual, 1, op_name);
}
condition = input_args[0]->BuildType();
}
std::vector<int64_t> output_shape = {1};
std::set<TypePtr> local_bool = {TypeIdToType(kNumberTypeBool)};
std::map<std::string, TypePtr> args = {{"condition", condition}};
CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, local_bool, op_name);
auto inputs_type = input_args[1]->BuildType()->cast<TuplePtr>()->elements();
for (auto dtype : inputs_type) {
std::set<TypePtr> template_types = {TypeIdToType(kObjectTypeTensorType)};
CheckAndConvertUtils::CheckSubClass("input", dtype, template_types, op_name);
}
return std::make_shared<abstract::AbstractTensor>(TypeIdToType(kNumberTypeInt32), output_shape);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Assert, prim::kPrimAssert, AssertInfer);
REGISTER_PRIMITIVE_C(kNameAssert, Assert);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,44 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_ASSERT_H_
#define MINDSPORE_CORE_OPS_ASSERT_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAssert = "Assert";
class Assert : public PrimitiveC {
public:
Assert() : PrimitiveC(kNameAssert) {}
~Assert() = default;
MS_DECLARE_PARENT(Assert, PrimitiveC);
void Init(const int64_t summarize = 3);
void set_summarize(const int64_t summarize);
int64_t get_summarize() const;
};
AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimAssertPtr = std::shared_ptr<Assert>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_ASSERT_H_

View File

@ -14,8 +14,18 @@
* limitations under the License.
*/
#include "c_ops/assign.h"
#include <set>
#include <map>
#include <vector>
#include <memory>
#include <string>
#include "ops/assign.h"
#include "ops/op_utils.h"
#include "ir/dtype/ref.h"
namespace mindspore {
namespace ops {
REGISTER_PRIMITIVE_C(kNameAssign, Assign);
} // namespace ops
} // namespace mindspore

View File

@ -14,13 +14,17 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_ASSIGN_H_
#define MINDSPORE_CORE_C_OPS_ASSIGN_H_
#include "c_ops/primitive_c.h"
#ifndef MINDSPORE_CORE_OPS_ASSIGN_H_
#define MINDSPORE_CORE_OPS_ASSIGN_H_
#include <memory>
#include <vector>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAssign = "Assign";
class Assign : public PrimitiveC {
public:
@ -29,6 +33,9 @@ class Assign : public PrimitiveC {
MS_DECLARE_PARENT(Assign, PrimitiveC);
void Init() {}
};
using PrimAssignPtr = std::shared_ptr<Assign>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_ASSIGN_H_
#endif // MINDSPORE_CORE_OPS_ASSIGN_H_

View File

@ -0,0 +1,53 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <map>
#include <string>
#include "ops/assign_add.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto assignadd_prim = primitive->cast<PrimAssignAddPtr>();
MS_EXCEPTION_IF_NULL(assignadd_prim);
auto prim_name = assignadd_prim->name();
auto value_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("value_shape", input_args[1]->BuildShape(), prim_name);
return std::make_shared<abstract::Shape>(value_shape);
}
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
types.emplace("w", input_args[1]->BuildType());
// check_scalar_or_tensor_types_same
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, "AssignAdd");
return TypeIdToType(infer_type);
}
} // namespace
AbstractBasePtr AssignAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(AssignAdd, prim::kPrimAssignAdd, AssignAddInfer);
REGISTER_PRIMITIVE_C(kNameAssignAdd, AssignAdd);
} // namespace ops
} // namespace mindspore

View File

@ -14,13 +14,17 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_ASSIGNADD_H_
#define MINDSPORE_CORE_C_OPS_ASSIGNADD_H_
#include "c_ops/primitive_c.h"
#ifndef MINDSPORE_CORE_OPS_ASSIGN_ADD_H_
#define MINDSPORE_CORE_OPS_ASSIGN_ADD_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAssignAdd = "AssignAdd";
class AssignAdd : public PrimitiveC {
public:
@ -29,6 +33,10 @@ class AssignAdd : public PrimitiveC {
MS_DECLARE_PARENT(AssignAdd, PrimitiveC);
void Init() {}
};
AbstractBasePtr AssignAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimAssignAddPtr = std::shared_ptr<AssignAdd>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_ASSIGNADD_H_
#endif // MINDSPORE_CORE_OPS_ASSIGN_ADD_H_

View File

@ -0,0 +1,50 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <set>
#include "ops/atan.h"
namespace mindspore {
namespace ops {
AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto atan_prim = primitive->cast<PrimAtanPtr>();
MS_EXCEPTION_IF_NULL(atan_prim);
auto prim_name = atan_prim->name();
CheckAndConvertUtils::CheckInteger("Atan_infer", input_args.size(), kEqual, 1, prim_name);
// Infer Shape
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto infer_shape = std::make_shared<abstract::Shape>(x_shape);
// Infer Type
auto dtype = input_args[0]->BuildType();
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeInt32};
CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim_name);
auto tensor_type = dtype->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto element = tensor_type->element();
MS_EXCEPTION_IF_NULL(element);
auto infer_type = std::make_shared<TensorType>(TypeIdToType(element->type_id()));
return std::make_shared<abstract::AbstractTensor>(infer_type, infer_shape->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(Atan, prim::kPrimAtan, AtanInfer);
REGISTER_PRIMITIVE_C(kNameAtan, Atan);
} // namespace ops
} // namespace mindspore

View File

@ -14,17 +14,18 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_ATAN_H_
#define MINDSPORE_CORE_C_OPS_ATAN_H_
#ifndef MINDSPORE_CORE_OPS_ATAN_H_
#define MINDSPORE_CORE_OPS_ATAN_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "c_ops/primitive_c.h"
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAtan = "Atan";
class Atan : public PrimitiveC {
public:
@ -33,6 +34,10 @@ class Atan : public PrimitiveC {
MS_DECLARE_PARENT(Atan, PrimitiveC);
void Init() {}
};
AbstractBasePtr ATanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimAtanPtr = std::shared_ptr<Atan>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_ATAN_H_
#endif // MINDSPORE_CORE_OPS_ATAN_H_

View File

@ -0,0 +1,125 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/audio_spectrogram.h"
#include <string>
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr AudioSpectrogramInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto audio_spectrogram_prim = primitive->cast<PrimAudioSpectrogramPtr>();
MS_EXCEPTION_IF_NULL(audio_spectrogram_prim);
auto prim_name = audio_spectrogram_prim->name();
auto input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
if (input_shape.size() != 2) {
MS_LOG(ERROR) << "input shape is error, which need to be 2 dimensions";
}
if (audio_spectrogram_prim->get_window_size() < 2) {
MS_LOG(ERROR) << "window size is too short, now is " << audio_spectrogram_prim->get_window_size();
}
if (audio_spectrogram_prim->get_stride() < 1) {
MS_LOG(ERROR) << "stride must be positive, now is " << audio_spectrogram_prim->get_stride();
}
std::vector<int64_t> infer_shape;
infer_shape.push_back(input_shape[1]);
int64_t sample_sub_window = input_shape[0] - audio_spectrogram_prim->get_window_size();
infer_shape.push_back(sample_sub_window < 0 ? 0 : 1 + sample_sub_window / audio_spectrogram_prim->get_stride());
int64_t fft_length = audio_spectrogram_prim->GetFftLength(audio_spectrogram_prim->get_window_size());
infer_shape.push_back(fft_length / 2 + 1);
MS_LOG(ERROR) << infer_shape;
return std::make_shared<abstract::Shape>(infer_shape);
}
TypePtr AudioSpectrogramInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto infer_type = input_args[0]->BuildType();
auto tensor_type = infer_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto data_type = tensor_type->element();
MS_EXCEPTION_IF_NULL(data_type);
return data_type;
}
} // namespace
void AudioSpectrogram::set_window_size(const int64_t window_size) {
this->AddAttr(kWindowSize, MakeValue(window_size));
}
int64_t AudioSpectrogram::get_window_size() const {
auto value_ptr = GetAttr(kWindowSize);
return GetValue<int64_t>(value_ptr);
}
void AudioSpectrogram::set_stride(const int64_t stride) { this->AddAttr(kStride, MakeValue(stride)); }
int64_t AudioSpectrogram::get_stride() const {
auto value_ptr = GetAttr(kStride);
return GetValue<int64_t>(value_ptr);
}
int64_t AudioSpectrogram::Log2Ceil(int64_t length) {
if (length == 0) {
return -1;
}
int64_t floor = 0;
for (int64_t i = 4; i >= 0; --i) {
const int64_t shift = (int64_t)(1 << i);
int64_t tmp = length >> shift;
if (tmp != 0) {
length = tmp;
floor += shift;
}
}
return length == (length & ~(length - 1)) ? floor : floor + 1;
}
int64_t AudioSpectrogram::GetFftLength(int64_t length) {
int64_t shift = Log2Ceil(length);
return 1 << shift;
}
void AudioSpectrogram::set_mag_square(const bool mag_square) { this->AddAttr(kMagSquare, MakeValue(mag_square)); }
bool AudioSpectrogram::get_mag_square() const {
auto value_ptr = GetAttr(kMagSquare);
return GetValue<bool>(value_ptr);
}
void AudioSpectrogram::Init(const int64_t window_size, const int64_t stride, const bool mag_square) {
this->set_window_size(window_size);
this->set_stride(stride);
this->set_mag_square(mag_square);
}
AbstractBasePtr AudioSpectrogramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(AudioSpectrogramInferType(primitive, input_args),
AudioSpectrogramInferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(AudioSpectrogram, prim::kPrimAudioSpectrogram, AudioSpectrogramInfer);
REGISTER_PRIMITIVE_C(kNameAudioSpectrogram, AudioSpectrogram);
} // namespace ops
} // namespace mindspore

View File

@ -14,31 +14,38 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_AUDIOSPECTROGRAM_H_
#define MINDSPORE_CORE_C_OPS_AUDIOSPECTROGRAM_H_
#ifndef MINDSPORE_CORE_OPS_AUDIO_SPECTROGRAM_H_
#define MINDSPORE_CORE_OPS_AUDIO_SPECTROGRAM_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "c_ops/primitive_c.h"
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAudioSpectrogram = "AudioSpectrogram";
class AudioSpectrogram : public PrimitiveC {
public:
AudioSpectrogram() : PrimitiveC(kNameAudioSpectrogram) {}
~AudioSpectrogram() = default;
MS_DECLARE_PARENT(AudioSpectrogram, PrimitiveC);
void Init(const int64_t &window_size, const int64_t &stride, const bool &mag_square);
void set_window_size(const int64_t &window_size);
void set_stride(const int64_t &stride);
void set_mag_square(const bool &mag_square);
void Init(const int64_t window_size, const int64_t stride, const bool mag_square);
void set_window_size(const int64_t window_size);
void set_stride(const int64_t stride);
void set_mag_square(const bool mag_square);
int64_t get_window_size() const;
int64_t get_stride() const;
bool get_mag_square() const;
int64_t Log2Ceil(int64_t length);
int64_t GetFftLength(int64_t length);
};
AbstractBasePtr AudioSpectrogramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimAudioSpectrogramPtr = std::shared_ptr<AudioSpectrogram>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_AUDIOSPECTROGRAM_H_
#endif // MINDSPORE_CORE_OPS_AUDIO_SPECTROGRAM_H_

View File

@ -14,25 +14,26 @@
* limitations under the License.
*/
#include "c_ops/avg_pool.h"
#include "ops/avg_pool.h"
#include <string>
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include "c_ops/op_utils.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
void AvgPool::set_padding(const std::string &padding) {
CheckAndConvertUtils::CheckString(kPadding, padding, {kValid, kSame}, this->name());
this->AddAttr(kPadding, MakeValue(padding));
namespace ops {
void AvgPool::set_pad_mode(const PadMode &pad_mode) {
int64_t swi = pad_mode;
this->AddAttr(kPadMode, MakeValue(swi));
}
std::string AvgPool::get_padding() const {
auto value_ptr = GetAttr(kPadding);
return GetValue<std::string>(value_ptr);
PadMode AvgPool::get_pad_mode() const {
auto value_ptr = GetAttr(kPadMode);
return PadMode(GetValue<int64_t>(value_ptr));
}
void AvgPool::set_kernel_size(const std::vector<int64_t> &kernel_size) {
this->AddAttr(kKernelSize, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name(),
@ -44,12 +45,12 @@ std::vector<int64_t> AvgPool::get_kernel_size() const {
return GetValue<std::vector<int64_t>>(value_ptr);
}
void AvgPool::set_strides(const std::vector<int64_t> &strides) {
this->AddAttr(kStride,
MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, strides, this->name(), false, true)));
this->AddAttr(kStrides,
MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name(), false, true)));
}
std::vector<int64_t> AvgPool::get_strides() const {
auto value_ptr = GetAttr(kStride);
auto value_ptr = GetAttr(kStrides);
return GetValue<std::vector<int64_t>>(value_ptr);
}
@ -70,20 +71,19 @@ std::vector<int64_t> AvgPool::get_pad() const {
return GetValue<std::vector<int64_t>>(value_ptr);
}
void AvgPool::set_round_mode(const int64_t &round_mode) {
CheckAndConvertUtils::CheckInRange(kRoundMode, round_mode, kIncludeBoth, {0, 1}, this->name());
this->AddAttr(kRoundMode, MakeValue(round_mode));
void AvgPool::set_round_mode(const RoundMode &round_mode) {
int64_t swi = round_mode;
this->AddAttr(kRoundMode, MakeValue(swi));
}
int64_t AvgPool::get_round_mode() const {
RoundMode AvgPool::get_round_mode() const {
auto value_ptr = GetAttr(kRoundMode);
return GetValue<int64_t>(value_ptr);
return RoundMode(GetValue<int64_t>(value_ptr));
}
void AvgPool::Init(const std::vector<int64_t> &kernel_size, const std::vector<int64_t> &stride,
const std::string &padding, const Format &format, const std::vector<int64_t> &pad,
const int64_t &round_mode) {
this->set_padding(padding);
void AvgPool::Init(const std::vector<int64_t> &kernel_size, const std::vector<int64_t> &stride, const PadMode &pad_mode,
const Format &format, const std::vector<int64_t> &pad, const RoundMode &round_mode) {
this->set_pad_mode(pad_mode);
this->set_kernel_size(kernel_size);
this->set_strides(stride);
this->set_format(format);
@ -98,9 +98,12 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
MS_EXCEPTION_IF_NULL(pool_prim);
auto op_name = pool_prim->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name);
if (pool_prim->get_format() == NHWC) {
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
}
CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name);
auto kernel_size = pool_prim->get_kernel_size();
auto pad_mode = pool_prim->get_padding();
auto pad_mode = pool_prim->get_pad_mode();
auto batch = in_shape[0];
auto channel = in_shape[1];
auto in_h = in_shape[2];
@ -113,14 +116,17 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
auto stride_w = strides[3];
int64_t out_h = -1;
int64_t out_w = -1;
if (pad_mode == "valid") {
if (pad_mode == VALID) {
out_h = ceil((in_h - (kernel_h - 1)) / stride_h);
out_w = ceil((in_w - (kernel_w - 1)) / stride_w);
} else if (pad_mode == "same") {
} else if (pad_mode == SAME) {
out_h = ceil(in_h / stride_h);
out_w = ceil(in_w / stride_w);
}
std::vector<int64_t> out_shape = {batch, channel, out_h, out_w};
if (pool_prim->get_format() == NHWC) {
out_shape = {batch, out_h, out_w, channel};
}
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
MS_LOG(EXCEPTION) << "Kernel size is not valid.";
}
@ -142,4 +148,5 @@ AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const Primitiv
}
REGISTER_PRIMITIVE_EVAL_IMPL(AvgPool, prim::kPrimAvgPool, AvgPoolInfer);
REGISTER_PRIMITIVE_C(kNameAvgPool, AvgPool);
} // namespace ops
} // namespace mindspore

View File

@ -14,45 +14,48 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_AVG_POOL_H_
#define MINDSPORE_CORE_C_OPS_AVG_POOL_H_
#ifndef MINDSPORE_CORE_OPS_AVG_POOL_H_
#define MINDSPORE_CORE_OPS_AVG_POOL_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "c_ops/primitive_c.h"
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAvgPool = "AvgPool";
class AvgPool : public PrimitiveC {
public:
AvgPool() : PrimitiveC(kNameAvgPool) { InitIOName({"x"}, {"output"}); }
explicit AvgPool(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x"}, {"output"}); }
~AvgPool() = default;
MS_DECLARE_PARENT(AvgPool, PrimitiveC);
void Init(const std::vector<int64_t> &kernel_size = {1}, const std::vector<int64_t> &stride = {1},
const std::string &padding = "valid", const Format &format = NCHW,
const std::vector<int64_t> &pad = {0, 0, 0, 0}, const int64_t &round_mode = 0);
void set_padding(const std::string &padding);
const PadMode &pad_mode = VALID, const Format &format = NCHW,
const std::vector<int64_t> &pad = {0, 0, 0, 0}, const RoundMode &round_mode = FLOOR);
void set_pad_mode(const PadMode &pad_mode);
void set_kernel_size(const std::vector<int64_t> &kernel_size);
void set_strides(const std::vector<int64_t> &strides);
void set_format(const Format &format);
void set_pad(const std::vector<int64_t> &pad);
void set_round_mode(const int64_t &round_mode);
void set_round_mode(const RoundMode &round_mode);
std::vector<int64_t> get_kernel_size() const;
std::vector<int64_t> get_strides() const;
std::string get_padding() const;
PadMode get_pad_mode() const;
Format get_format() const;
std::vector<int64_t> get_pad() const;
int64_t get_round_mode() const;
RoundMode get_round_mode() const;
};
AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimAvgPoolPtr = std::shared_ptr<AvgPool>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_AVG_POOL_H_
#endif // MINDSPORE_CORE_OPS_AVG_POOL_H_

View File

@ -0,0 +1,140 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include "ops/batch_norm.h"
#include "abstract/primitive_infer_map.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
void BatchNorm::Init(const bool is_training, const float epsilon, const float momentum, const Format &format) {
set_is_training(is_training);
set_epsilon(epsilon);
set_format(format);
set_momentum(momentum);
}
void BatchNorm::set_is_training(const bool is_training) { this->AddAttr(kIsTraining, MakeValue(is_training)); }
void BatchNorm::set_epsilon(const float epsilon) {
CheckAndConvertUtils::CheckInRange<float>(kEpsilon, epsilon, kIncludeBoth, {0.0, 1.0}, this->name());
this->AddAttr(kEpsilon, MakeValue(epsilon));
}
void BatchNorm::set_format(const Format &format) {
int64_t f = format;
this->AddAttr(kFormat, MakeValue(f));
}
void BatchNorm::set_momentum(const float momentun) {
CheckAndConvertUtils::CheckInRange<int64_t>(kMomentum, momentun, kIncludeBoth, {0.0, 1.0}, this->name());
this->AddAttr(kMomentum, MakeValue(momentun));
}
float BatchNorm::get_momentum() const {
auto value_ptr = GetAttr(kMomentum);
return GetValue<float>(value_ptr);
}
bool BatchNorm::get_is_training() const {
auto value_ptr = GetAttr(kIsTraining);
return GetValue<bool>(value_ptr);
}
float BatchNorm::get_epsilon() const {
auto value_ptr = GetAttr(kEpsilon);
return GetValue<float>(value_ptr);
}
Format BatchNorm::get_format() const {
auto value_ptr = GetAttr(kFormat);
return Format(GetValue<int64_t>(value_ptr));
}
AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
// Infer shape
MS_EXCEPTION_IF_NULL(primitive);
auto batch_prim = primitive->cast<PrimBatchNormPtr>();
MS_EXCEPTION_IF_NULL(batch_prim);
auto prim_name = batch_prim->name();
CheckAndConvertUtils::CheckInteger("batch_norm_infer", input_args.size(), kEqual, 5, prim_name);
auto input_x = CheckAndConvertUtils::ConvertShapePtrToShape("input_x", input_args[0]->BuildShape(), prim_name);
if (batch_prim->get_format() == NHWC) {
input_x = {input_x[0], input_x[3], input_x[1], input_x[2]};
}
auto scale = CheckAndConvertUtils::ConvertShapePtrToShape("scale", input_args[1]->BuildShape(), prim_name);
auto bias = CheckAndConvertUtils::ConvertShapePtrToShape("bias", input_args[2]->BuildShape(), prim_name);
auto mean = CheckAndConvertUtils::ConvertShapePtrToShape("mean", input_args[3]->BuildShape(), prim_name);
auto variance = CheckAndConvertUtils::ConvertShapePtrToShape("variance", input_args[4]->BuildShape(), prim_name);
std::vector<int64_t> input_shape_norm;
if (batch_prim->get_format() == NCHW) {
input_shape_norm =
CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name);
} else {
input_shape_norm.push_back(input_x[0]);
input_shape_norm.push_back(input_x[3]);
input_shape_norm.push_back(input_x[1]);
input_shape_norm.push_back(input_x[2]);
}
CheckAndConvertUtils::CheckInteger("scale rank", scale.size(), kEqual, 1, prim_name);
CheckAndConvertUtils::Check("scale shape", scale, kEqual, "bias shape", bias, prim_name, TypeError);
CheckAndConvertUtils::Check("scale shape[0]", scale[0], kEqual, "input_x channel", input_shape_norm[1], prim_name,
TypeError);
if (!batch_prim->get_is_training()) {
CheckAndConvertUtils::CheckInteger("mean rank", mean.size(), kEqual, 1, prim_name);
CheckAndConvertUtils::Check("mean shape", mean, kEqual, "variance shape", variance, prim_name, TypeError);
CheckAndConvertUtils::Check("mean shape", mean, kEqual, "scale shape", scale, prim_name, TypeError);
}
// Infer type
auto input_x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
auto scale_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element();
auto bias_type = input_args[2]->BuildType()->cast<TensorTypePtr>()->element();
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name);
std::map<std::string, TypePtr> args;
args.emplace("scale", input_args[1]->BuildType());
args.emplace("bias", input_args[2]->BuildType());
CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
std::map<std::string, TypePtr> args_moving;
args_moving.emplace("scale", input_args[2]->BuildType());
args_moving.emplace("bias", input_args[3]->BuildType());
CheckAndConvertUtils::CheckTensorTypeSame(args_moving, valid_types, prim_name);
auto output0 = std::make_shared<abstract::AbstractTensor>(input_x_type, input_x);
auto output1 = std::make_shared<abstract::AbstractTensor>(scale_type, scale);
auto output2 = std::make_shared<abstract::AbstractTensor>(bias_type, scale);
auto output3 = std::make_shared<abstract::AbstractTensor>(input_x_type, scale);
if (batch_prim->get_format() == NHWC) {
output2 = std::make_shared<abstract::AbstractTensor>(scale_type, scale);
output3 = std::make_shared<abstract::AbstractTensor>(bias_type, scale);
output1 = std::make_shared<abstract::AbstractTensor>(input_x_type, scale);
}
AbstractBasePtrList output = {output0, output1, output2, output3, output3};
return std::make_shared<abstract::AbstractTuple>(output);
}
REGISTER_PRIMITIVE_EVAL_IMPL(BatchNorm, prim::kPrimBatchNorm, BatchNormInfer);
REGISTER_PRIMITIVE_C(kNameBatchNorm, BatchNorm);
} // namespace ops
} // namespace mindspore

View File

@ -14,17 +14,18 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_BATCH_NORMAL_H_
#define MINDSPORE_CORE_C_OPS_BATCH_NORMAL_H_
#ifndef MINDSPORE_CORE_OPS_BATCH_NORMAL_H_
#define MINDSPORE_CORE_OPS_BATCH_NORMAL_H_
#include <map>
#include <vector>
#include <memory>
#include <string>
#include "c_ops/op_utils.h"
#include "c_ops/primitive_c.h"
#include "ops/op_utils.h"
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
namespace mindspore {
namespace ops {
constexpr auto kNameBatchNorm = "BatchNorm";
class BatchNorm : public PrimitiveC {
public:
@ -34,19 +35,23 @@ class BatchNorm : public PrimitiveC {
}
~BatchNorm() = default;
MS_DECLARE_PARENT(BatchNorm, PrimitiveC);
void Init(bool is_training = false, float epsilon = 1e-5, const Format &format = NCHW);
void set_is_training(bool is_training);
void set_epsilon(float epsilon);
void Init(const bool is_training = false, const float epsilon = 1e-5, const float momentun = 0.1,
const Format &format = NCHW);
void set_is_training(const bool is_training);
void set_epsilon(const float epsilon);
void set_format(const Format &format);
bool get_is_trainging();
float get_epsilon();
void set_momentum(const float momentum);
bool get_is_training() const;
float get_epsilon() const;
Format get_format() const;
float get_momentum() const;
};
AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimBatchNormPtr = std::shared_ptr<BatchNorm>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_BatchNorm_H_
#endif // MINDSPORE_CORE_OPS_BatchNorm_H_

View File

@ -0,0 +1,116 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <map>
#include <string>
#include "ops/batch_norm_fold.h"
#include "ops/op_utils.h"
namespace mindspore {
namespace ops {
void BatchNormFold::Init(const float momentum, const float epsilon, const bool is_training, const int64_t freeze_bn) {
set_momentum(momentum);
set_epsilon(epsilon);
set_is_training(is_training);
set_freeze_bn(freeze_bn);
}
void BatchNormFold::set_momentum(const float momentum) {
CheckAndConvertUtils::CheckInRange<int64_t>(kMomentum, momentum, kIncludeBoth, {0.0, 1.0}, this->name());
this->AddAttr(kMomentum, MakeValue(momentum));
}
float BatchNormFold::get_momentum() const {
auto value_ptr = GetAttr(kMomentum);
return GetValue<float>(value_ptr);
}
void BatchNormFold::set_epsilon(const float epsilon) {
float match_value = 0.0;
CheckAndConvertUtils::CheckValue(kEpsilon, epsilon, kGreaterThan, match_value, this->name());
this->AddAttr(kEpsilon, MakeValue(epsilon));
}
float BatchNormFold::get_epsilon() const {
auto value_ptr = GetAttr(kEpsilon);
return GetValue<float>(value_ptr);
}
void BatchNormFold::set_is_training(const bool is_training) { this->AddAttr(kIsTraining, MakeValue(is_training)); }
bool BatchNormFold::get_is_training() const {
auto value_ptr = GetAttr(kIsTraining);
return GetValue<bool>(value_ptr);
}
void BatchNormFold::set_freeze_bn(const int64_t freeze_bn) { this->AddAttr(kFreezeBn, MakeValue(freeze_bn)); }
int64_t BatchNormFold::get_freeze_bn() const {
auto value_ptr = GetAttr(kFreezeBn);
return GetValue<int64_t>(value_ptr);
}
AbstractBasePtr BatchNormFoldInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto BatchNormFold_prim = primitive->cast<PrimBatchNormFoldPtr>();
MS_EXCEPTION_IF_NULL(BatchNormFold_prim);
auto op_name = BatchNormFold_prim->name();
auto mean_shape = CheckAndConvertUtils::ConvertShapePtrToShape("mean_shape", input_args[1]->BuildShape(), op_name);
auto variance_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("variance_shape", input_args[2]->BuildShape(), op_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name);
auto global_step_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("global_step_shape", input_args[3]->BuildShape(), op_name);
CheckAndConvertUtils::Check("mean_shape", mean_shape, kEqual, "gamma_shape", variance_shape, op_name);
CheckAndConvertUtils::Check("mean_shape[0]", mean_shape[0], kEqual, "input channel", x_shape[1], op_name);
CheckAndConvertUtils::CheckInteger("global step shape len", global_step_shape.size(), kEqual, 1, op_name);
auto mean_type = input_args[1]->BuildType();
auto variance_type = input_args[2]->BuildType();
auto x_type = input_args[0]->BuildType();
auto global_step_type = input_args[3]->BuildType();
std::map<std::string, TypePtr> args = {{"x", x_type}, {"mean", mean_type}, {"variance", variance_type}};
CheckAndConvertUtils::CheckTensorTypeSame(args, {kNumberTypeFloat16, kNumberTypeFloat32}, op_name);
CheckAndConvertUtils::CheckTensorTypeValid("gloabal_step", global_step_type, {kNumberTypeInt32}, op_name);
auto tensor_type0 = x_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type0);
auto element0 = tensor_type0->element();
auto tensor_type1 = mean_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type1);
auto element1 = tensor_type1->element();
auto tensor_type2 = variance_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type2);
auto element2 = tensor_type2->element();
CheckAndConvertUtils::Check("input type", element0->type_id(), kEqual, "mean_type", element1->type_id(), op_name);
CheckAndConvertUtils::Check("input type", element0->type_id(), kEqual, "variance_type", element2->type_id(), op_name);
auto output = std::make_shared<abstract::AbstractTensor>(element0, mean_shape);
AbstractBasePtrList output1 = {output, output, output, output};
return std::make_shared<abstract::AbstractTuple>(output1);
}
REGISTER_PRIMITIVE_EVAL_IMPL(BatchNormFold, prim::kPrimBatchNormFold, BatchNormFoldInfer);
REGISTER_PRIMITIVE_C(kNameBatchNormFold, BatchNormFold);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,54 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_BATCH_NORM_FOLD_H_
#define MINDSPORE_CORE_OPS_BATCH_NORM_FOLD_H_
#include <memory>
#include <vector>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameBatchNormFold = "BatchNormFold";
class BatchNormFold : public PrimitiveC {
public:
BatchNormFold() : PrimitiveC(kNameBatchNormFold) {
InitIOName({"x", "mean", "variance", "global_step"}, {"batch_mean", "batch_std", "running_mean", "running_std"});
}
~BatchNormFold() = default;
MS_DECLARE_PARENT(BatchNormFold, PrimitiveC);
void Init(const float momentum = 0.9, const float epsilon = 1e-5, const bool is_training = true,
const int64_t freeze_bn = 0);
void set_momentum(const float momentum);
void set_epsilon(const float epsilon);
void set_is_training(const bool is_training);
void set_freeze_bn(const int64_t freeze_bn);
float get_momentum() const;
float get_epsilon() const;
bool get_is_training() const;
int64_t get_freeze_bn() const;
};
AbstractBasePtr BatchNormFoldInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimBatchNormFoldPtr = std::shared_ptr<BatchNormFold>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_BATCH_NORM_FOLD_H_

View File

@ -0,0 +1,81 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/batch_to_space.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
void BatchToSpace::Init(const std::vector<int64_t> &block_size, const std::vector<std::vector<int64_t>> &crops) {
this->set_block_size(block_size);
this->set_crops(crops);
}
void BatchToSpace::set_block_size(const std::vector<int64_t> &block_size) {
this->AddAttr(kBlockSize, MakeValue(block_size));
}
std::vector<int64_t> BatchToSpace::get_block_size() const {
auto value_ptr = this->GetAttr(kBlockSize);
return GetValue<std::vector<int64_t>>(value_ptr);
}
void BatchToSpace::set_crops(const std::vector<std::vector<int64_t>> &crops) {
this->AddAttr(kCrops, MakeValue(crops));
}
std::vector<std::vector<int64_t>> BatchToSpace::get_crops() const {
auto value_ptr = this->GetAttr(kCrops);
return GetValue<std::vector<std::vector<int64_t>>>(value_ptr);
}
AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim = primitive->cast<PrimBatchToSpacePtr>();
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), common_valid_types, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
auto block_size = prim->get_block_size();
auto crops = prim->get_crops();
auto out_shape = x_shape;
for (size_t i = 0; i < 2; ++i) {
auto x_block_prod = out_shape[i + 2] * block_size[i];
auto crops_sum = crops[i][0] + crops[i][1];
CheckAndConvertUtils::Check("x block shape prod", x_block_prod, kGreaterThan, "crops sum", 4, prim_name);
out_shape[i + 2] = x_block_prod - crops_sum;
}
CheckAndConvertUtils::CheckInteger("x_shape[0] % (block_size[0]*block_size[1])",
out_shape[0] % (block_size[0] * block_size[1]), kEqual, 0, prim_name);
out_shape[0] /= block_size[0] * block_size[1];
auto ret = input_args[0]->Broaden();
ret->set_shape(std::make_shared<abstract::Shape>(out_shape));
return ret;
}
REGISTER_PRIMITIVE_EVAL_IMPL(BatchToSpace, prim::kPrimBatchToSpace, BatchToSpaceInfer);
REGISTER_PRIMITIVE_C(kNameBatchToSpace, BatchToSpace);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,47 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_BATCH_TO_SPACE_H_
#define MINDSPORE_CORE_OPS_BATCH_TO_SPACE_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameBatchToSpace = "BatchToSpace";
class BatchToSpace : public PrimitiveC {
public:
BatchToSpace() : PrimitiveC(kNameBatchToSpace) {}
~BatchToSpace() = default;
MS_DECLARE_PARENT(BatchToSpace, PrimitiveC);
void Init(const std::vector<int64_t> &block_size, const std::vector<std::vector<int64_t>> &crops);
void set_block_size(const std::vector<int64_t> &block_size);
void set_crops(const std::vector<std::vector<int64_t>> &crops);
std::vector<int64_t> get_block_size() const;
std::vector<std::vector<int64_t>> get_crops() const;
};
AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimBatchToSpacePtr = std::shared_ptr<BatchToSpace>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_BATCH_TO_SPACE_H_

View File

@ -0,0 +1,109 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/batch_to_space_nd.h"
#include <string>
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto batch_prim = primitive->cast<PrimBatchToSpaceNDPtr>();
MS_EXCEPTION_IF_NULL(batch_prim);
auto prim_name = batch_prim->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name);
auto out_shape = x_shape;
int64_t block_shape_prod = 1;
int64_t offset = 2;
auto block_shape = batch_prim->get_block_shape();
auto crops = batch_prim->get_crops();
int64_t size = block_shape.size();
for (int64_t i = 0; i < size; i++) {
block_shape_prod = block_shape_prod * block_shape[i];
auto x_block_prod = out_shape[i + offset] * block_shape[i];
auto crops_sum = crops[i][0] + crops[i][1];
CheckAndConvertUtils::Check("x block shape prod", x_block_prod, kGreaterThan, "crops sum", crops_sum, prim_name);
out_shape[i + offset] = x_block_prod - crops_sum;
}
if (out_shape[0] % block_shape_prod != 0) {
MS_EXCEPTION(ValueError) << prim_name << " input_x dimension 0 " << out_shape[0]
<< " should be divisible by block_shape_prod " << block_shape_prod;
}
out_shape[0] = int64_t(floor(out_shape[0] / block_shape_prod));
return std::make_shared<abstract::Shape>(out_shape);
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto infer_type = input_args[0]->BuildType();
return infer_type;
}
} // namespace
void BatchToSpaceND::set_crops(std::vector<std::vector<int64_t>> crops) {
CheckAndConvertUtils::CheckInteger(kCrops, crops.size(), kEqual, 2, this->name());
int64_t h = crops.size();
int64_t w = crops[0].size();
std::vector<int64_t> temp_w = {2, 2};
CheckAndConvertUtils::Check(kCrops, {h, w}, kEqual, "paddings_shape(2,2)", temp_w, this->name());
for (int64_t i = 0; i < h; i++) {
for (int64_t j = 0; j < w; j++) {
CheckAndConvertUtils::CheckInteger(kCrops, crops[i][j], kGreaterEqual, 0, this->name());
}
}
this->AddAttr(kCrops, MakeValue(crops));
}
std::vector<std::vector<int64_t>> BatchToSpaceND::get_crops() const {
auto value_ptr = GetAttr(kCrops);
return GetValue<std::vector<std::vector<int64_t>>>(value_ptr);
}
void BatchToSpaceND::set_block_shape(std::vector<int64_t> block_shape) {
CheckAndConvertUtils::CheckInteger(kBlockShape, block_shape.size(), kEqual, 2, this->name());
for (int64_t i = 0; i < (int64_t)block_shape.size(); i++) {
CheckAndConvertUtils::CheckInteger(kBlockShape, block_shape[i], kGreaterEqual, 1, this->name());
}
this->AddAttr(kBlockShape, MakeValue(block_shape));
}
std::vector<int64_t> BatchToSpaceND::get_block_shape() const {
auto value_ptr = GetAttr(kBlockShape);
return GetValue<std::vector<int64_t>>(value_ptr);
}
void BatchToSpaceND::Init(std::vector<int64_t> block_shape, std::vector<std::vector<int64_t>> crops) {
this->set_crops(crops);
this->set_block_shape(block_shape);
}
AbstractBasePtr BatchToSpaceNDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(BatchToSpaceND, prim::kPrimBatchToSpaceND, BatchToSpaceNDInfer);
REGISTER_PRIMITIVE_C(kNameBatchToSpaceND, BatchToSpaceND);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,48 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_BATCH_TO_SPACE_ND_H_
#define MINDSPORE_CORE_OPS_BATCH_TO_SPACE_ND_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameBatchToSpaceND = "BatchToSpaceND";
class BatchToSpaceND : public PrimitiveC {
public:
BatchToSpaceND() : PrimitiveC(kNameBatchToSpaceND) {}
~BatchToSpaceND() = default;
MS_DECLARE_PARENT(BatchToSpaceND, PrimitiveC);
void Init(std::vector<int64_t> block_shape, std::vector<std::vector<int64_t>> crops);
void set_crops(std::vector<std::vector<int64_t>> crops);
void set_block_shape(std::vector<int64_t> block_shape);
std::vector<int64_t> get_block_shape() const;
std::vector<std::vector<int64_t>> get_crops() const;
};
AbstractBasePtr BatchToSpaceNDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimBatchToSpaceNDPtr = std::shared_ptr<BatchToSpaceND>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_BATCH_TO_SPACE_ND_H_

View File

@ -14,12 +14,15 @@
* limitations under the License.
*/
#include "c_ops/bias_add.h"
#include "ops/bias_add.h"
#include <memory>
#include "c_ops/op_utils.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
// Add
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
void BiasAdd::set_format(const Format &format) {
int64_t f = format;
this->AddAttr(kFormat, MakeValue(f));
@ -29,5 +32,7 @@ Format BiasAdd::get_format() const {
return Format(GetValue<int64_t>(value_ptr));
}
void BiasAdd::Init(const Format &format) { this->set_format(format); }
// Add
REGISTER_PRIMITIVE_C(kNameBiasAdd, BiasAdd);
} // namespace ops
} // namespace mindspore

View File

@ -14,17 +14,20 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_BIASADD_H_
#define MINDSPORE_CORE_C_OPS_BIASADD_H_
#ifndef MINDSPORE_CORE_OPS_BIAS_ADD_H_
#define MINDSPORE_CORE_OPS_BIAS_ADD_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "c_ops/primitive_c.h"
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
// Add
#include "ops/op_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameBiasAdd = "BiasAdd";
class BiasAdd : public PrimitiveC {
public:
@ -35,6 +38,7 @@ class BiasAdd : public PrimitiveC {
void set_format(const Format &format);
Format get_format() const;
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_BIASADD_H_
#endif // MINDSPORE_CORE_OPS_BIAS_ADD_H_

View File

@ -0,0 +1,94 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include <map>
#include "ops/binary_cross_entropy.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr BinaryCrossEntroyInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto binary_cross_entropy_prim = primitive->cast<PrimBinaryCrossEntropyPtr>();
MS_EXCEPTION_IF_NULL(binary_cross_entropy_prim);
auto prim_name = binary_cross_entropy_prim->name();
CheckAndConvertUtils::CheckInRange("binary_cross_entropy_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name);
auto weight_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("weight_shape", input_args[2]->BuildShape(), prim_name);
CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name);
std::vector<int64_t> infer_shape;
if (weight_shape.size() < 1) {
CheckAndConvertUtils::Check("x shape", y_shape, kEqual, "weight shape", weight_shape, prim_name);
}
if (binary_cross_entropy_prim->get_reduction() != REDUCTION_SUM &&
binary_cross_entropy_prim->get_reduction() != MEAN) {
infer_shape = {x_shape.begin(), infer_shape.end()};
}
return std::make_shared<abstract::Shape>(infer_shape);
}
TypePtr BinaryCrossEntroyInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
CheckAndConvertUtils::CheckInteger("binary_cross_entropy_infer", input_args.size(), kEqual, 3, prim->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
std::map<std::string, TypePtr> types;
types.emplace("x_shape", input_args[0]->BuildType());
types.emplace("y_shape", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
if (input_args[3]->BuildType() != nullptr) {
types.emplace("x_shape", input_args[0]->BuildType());
types.emplace("weight_shape", input_args[2]->BuildType());
infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
return TypeIdToType(infer_type);
}
} // namespace
void BinaryCrossEntropy::set_reduction(const Reduction &reduction) {
int64_t swi = reduction;
this->AddAttr(kReduction, MakeValue(swi));
}
Reduction BinaryCrossEntropy::get_reduction() const {
auto value_ptr = GetAttr(kReduction);
return Reduction(GetValue<int64_t>(value_ptr));
}
void BinaryCrossEntropy::Init(const Reduction &reduction) { this->set_reduction(reduction); }
AbstractBasePtr BinaryCrossEntropyInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(BinaryCrossEntroyInferType(primitive, input_args),
BinaryCrossEntroyInferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(BinaryCrossEntropy, prim::kPrimBinaryCrossEntropy, BinaryCrossEntropyInfer);
REGISTER_PRIMITIVE_C(kNameBinaryCrossEntropy, BinaryCrossEntropy);
} // namespace ops
} // namespace mindspore

View File

@ -14,25 +14,32 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_BINARY_CROSS_ENTROPY_GRAD_H_
#define MINDSPORE_CORE_C_OPS_BINARY_CROSS_ENTROPY_GRAD_H_
#ifndef MINDSPORE_CORE_OPS_BINARY_CROSS_ENTROPY_H_
#define MINDSPORE_CORE_OPS_BINARY_CROSS_ENTROPY_H_
#include <string>
#include <vector>
#include <memory>
#include "c_ops/primitive_c.h"
#include "c_ops/op_utils.h"
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameBinaryCrossEntropy = "BinaryCrossEntropy";
class BinaryCrossEntropy : public PrimitiveC {
public:
BinaryCrossEntropy() : PrimitiveC(kNameBinaryCrossEntropy) {}
~BinaryCrossEntropy() = default;
MS_DECLARE_PARENT(BinaryCrossEntropy, PrimitiveC);
void Init(const std::string &reduction = "mean");
void set_reduction(const std::string &reduction);
std::string get_reduction() const;
void Init(const Reduction &reduction = MEAN);
void set_reduction(const Reduction &reduction);
Reduction get_reduction() const;
};
AbstractBasePtr BinaryCrossEntropyGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimBinaryCrossEntropyPtr = std::shared_ptr<BinaryCrossEntropy>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_BINARY_CROSS_ENTROPY_GRAD_H_
#endif // MINDSPORE_CORE_OPS_BINARY_CROSS_ENTROPY_H_

View File

@ -14,13 +14,14 @@
* limitations under the License.
*/
#include "c_ops/black_box.h"
#include "c_ops/op_utils.h"
#include "ops/black_box.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
void BlackBox::Init(const std::string &id, int64_t size, const std::vector<int64_t> &address) {
namespace ops {
void BlackBox::Init(const std::string &id, const int64_t size, const std::vector<int64_t> &address) {
this->set_id(id);
this->set_size(size);
this->set_address(address);
@ -33,7 +34,7 @@ std::string BlackBox::get_id() const {
return GetValue<std::string>(value_ptr);
}
void BlackBox::set_size(int64_t size) { this->AddAttr(kSize, MakeValue(size)); }
void BlackBox::set_size(const int64_t size) { this->AddAttr(kSize, MakeValue(size)); }
int64_t BlackBox::get_size() const {
auto value_ptr = this->GetAttr(kSize);
@ -47,4 +48,5 @@ std::vector<int64_t> BlackBox::get_address() const {
return GetValue<std::vector<int64_t>>(value_ptr);
}
REGISTER_PRIMITIVE_C(kNameBlackBox, BlackBox);
} // namespace ops
} // namespace mindspore

View File

@ -14,25 +14,26 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_BLACKBOX_H_
#define MINDSPORE_CORE_C_OPS_BLACKBOX_H_
#ifndef MINDSPORE_CORE_OPS_BLACK_BOX_H_
#define MINDSPORE_CORE_OPS_BLACK_BOX_H_
#include <string>
#include <vector>
#include <memory>
#include "c_ops/primitive_c.h"
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameBlackBox = "BlackBox";
class BlackBox : public PrimitiveC {
public:
BlackBox() : PrimitiveC(kNameBlackBox) {}
~BlackBox() = default;
MS_DECLARE_PARENT(BlackBox, PrimitiveC);
void Init(const std::string &id, int64_t size, const std::vector<int64_t> &address);
void Init(const std::string &id, const int64_t size, const std::vector<int64_t> &address);
void set_id(const std::string &id);
void set_size(int64_t size);
void set_size(const int64_t size);
void set_address(const std::vector<int64_t> &address);
std::string get_id() const;
int64_t get_size() const;
@ -40,6 +41,7 @@ class BlackBox : public PrimitiveC {
};
using PrimBlackBoxPtr = std::shared_ptr<BlackBox>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_BLACKBOX_H_
#endif // MINDSPORE_CORE_OPS_BLACK_BOX_H_

View File

@ -0,0 +1,70 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <set>
#include <vector>
#include <memory>
#include "ops/broadcast.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
void Broadcast::Init(const int64_t root_rank, const std::string &group) {
this->set_root_rank(root_rank);
this->set_group(group);
}
void Broadcast::set_root_rank(const int64_t root_rank) { this->AddAttr(kKeepProb, MakeValue(root_rank)); }
void Broadcast::set_group(const std::string &group) {
CheckAndConvertUtils::CheckString(kGroup, group, {"hccl_world_group", "hccl_world_group"}, this->name());
this->AddAttr(kGroup, MakeValue(group));
}
int64_t Broadcast::get_root_rank() const {
auto value_ptr = this->GetAttr(kRootRank);
return GetValue<float>(value_ptr);
}
std::string Broadcast::get_group() const {
auto value_ptr = this->GetAttr(kGroup);
return GetValue<std::string>(value_ptr);
}
AbstractBasePtr BroadcastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto broadcast_prim = primitive->cast<PrimBroadcast>();
MS_EXCEPTION_IF_NULL(broadcast_prim);
auto prim_name = broadcast_prim->name();
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
// infer shape
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
// infer type
auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
std::vector<TypePtr> output_types;
const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32};
for (size_t i = 0; i < input_args.size(); i++) {
auto out_type = input_args[i]->BuildType()->cast<TensorTypePtr>()->element();
output_types.push_back(out_type);
CheckAndConvertUtils::CheckTensorTypeValid("index_type", out_type, valid_types, prim_name);
}
return std::make_shared<abstract::AbstractTensor>(x_type, in_shape);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Broadcast, prim::kPrimBroadcast, BroadcastInfer);
REGISTER_PRIMITIVE_C(kNameBroadcast, Broadcast);
} // namespace ops
} // namespace mindspore

View File

@ -14,27 +14,34 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_BROADCAST_H_
#define MINDSPORE_CORE_C_OPS_BROADCAST_H_
#ifndef MINDSPORE_CORE_OPS_BROADCAST_H_
#define MINDSPORE_CORE_OPS_BROADCAST_H_
#include <string>
#include <vector>
#include <memory>
#include "c_ops/primitive_c.h"
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameBroadcast = "Broadcast";
class Broadcast : public PrimitiveC {
public:
Broadcast() : PrimitiveC(kNameBroadcast) {}
~Broadcast() = default;
MS_DECLARE_PARENT(Broadcast, PrimitiveC);
void Init(int64_t root_rank, const std::string &group = "hccl_world_group");
void set_root_rank(int64_t root_rank);
void Init(const int64_t root_rank, const std::string &group = "hccl_world_group");
void set_root_rank(const int64_t root_rank);
void set_group(const std::string &group);
int64_t get_root_rank();
int64_t get_root_rank() const;
std::string get_group() const;
};
AbstractBasePtr BroadcastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimBroadcast = std::shared_ptr<Broadcast>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_BROADCAST_H_
#endif // MINDSPORE_CORE_OPS_BROADCAST_H_

View File

@ -0,0 +1,88 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <set>
#include "ops/broadcast_to.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr BroadcastToInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto broad_cast_to = primitive->cast<PrimBroadcastToPtr>();
MS_EXCEPTION_IF_NULL(broad_cast_to);
auto prim_name = broad_cast_to->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto input_x = broad_cast_to->get_shape();
int64_t outer_dim_offset = input_x.size() - x_shape.size();
CheckAndConvertUtils::Check("x shape", x_shape, kLessEqual, "input_x", input_x, prim_name);
bool flag = true;
if (input_x.end() == find(input_x.begin(), input_x.end(), -1)) {
flag = false;
} else {
flag = true;
}
if (flag == true) {
for (int64_t i = 0; i < (int64_t)input_x.size(); i++) {
if (input_x[i] == -1) {
if (i < outer_dim_offset) {
MS_EXCEPTION(ValueError) << " -1 in init shape is in an incompatible "
"location with given input tensor, -1 index in init shape: "
<< i << " but -1 can only be in index" << x_shape.size()
<< "onwards for this input.";
}
input_x[i] = x_shape[i - outer_dim_offset];
}
}
}
std::reverse(input_x.begin(), input_x.end());
return std::make_shared<abstract::Shape>(input_x);
}
TypePtr BroadcastToInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto x_dtype = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
std::set<TypePtr> template_types = {TypeIdToType(kObjectTypeTensorType)};
CheckAndConvertUtils::CheckSubClass("x_dtype", x_dtype, template_types, prim->name());
auto infer_dtype = input_args[0]->BuildType()->type_id();
return TypeIdToType(infer_dtype);
}
} // namespace
void BroadcastTo::Init(const std::vector<int64_t> &shape) { set_shape(shape); }
void BroadcastTo::set_shape(const std::vector<int64_t> &shape) {
CheckAndConvertUtils::CheckInteger(kShapeSize, shape.size(), kGreaterThan, 0, name());
AddAttr(kShape, MakeValue(shape));
}
std::vector<int64_t> BroadcastTo::get_shape() const {
auto value_ptr = GetAttr(kShape);
return GetValue<std::vector<int64_t>>(value_ptr);
}
AbstractBasePtr BroadcastToInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(BroadcastToInferType(primitive, input_args),
BroadcastToInferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(BroadcastTo, prim::kPrimBroadcastTo, BroadcastToInfer);
REGISTER_PRIMITIVE_C(kNameBroadcastTo, BroadcastTo);
} // namespace ops
} // namespace mindspore

View File

@ -14,18 +14,19 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_BROADCAST_H_
#define MINDSPORE_CORE_C_OPS_BROADCAST_H_
#ifndef MINDSPORE_CORE_OPS_BROADCAST_TO_H_
#define MINDSPORE_CORE_OPS_BROADCAST_TO_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "c_ops/op_utils.h"
#include "c_ops/primitive_c.h"
#include "ops/op_utils.h"
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameBroadcastTo = "BroadcastTo";
class BroadcastTo : public PrimitiveC {
public:
@ -41,6 +42,7 @@ AbstractBasePtr BroadcastToInfer(const abstract::AnalysisEnginePtr &, const Prim
const std::vector<AbstractBasePtr> &input_args);
using PrimBroadcastToPtr = std::shared_ptr<BroadcastTo>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_BROADCAST_H_
#endif // MINDSPORE_CORE_OPS_BROADCAST_TO_H_

View File

@ -14,8 +14,10 @@
* limitations under the License.
*/
#include "c_ops/cast.h"
#include "ops/cast.h"
namespace mindspore {
namespace ops {
REGISTER_PRIMITIVE_C(kNameCast, Cast);
}
} // namespace ops
} // namespace mindspore

View File

@ -14,17 +14,18 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_CAST_H_
#define MINDSPORE_CORE_C_OPS_CAST_H_
#ifndef MINDSPORE_CORE_OPS_CAST_H_
#define MINDSPORE_CORE_OPS_CAST_H_
#include <vector>
#include <memory>
#include "c_ops/primitive_c.h"
#include "c_ops/op_utils.h"
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameCast = "Cast";
class Cast : public PrimitiveC {
public:
@ -35,5 +36,6 @@ class Cast : public PrimitiveC {
AbstractBasePtr CastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimCast = std::shared_ptr<Cast>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_CAST_H_
#endif // MINDSPORE_CORE_OPS_CAST_H_

View File

@ -0,0 +1,49 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <set>
#include <algorithm>
#include <memory>
#include <vector>
#include "ops/ceil.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
AbstractBasePtr CeilInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Ceil");
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
auto infer_type = input_args[0]->BuildType();
CheckAndConvertUtils::CheckTensorTypeValid("x type", infer_type, valid_types, primitive->name());
MS_EXCEPTION_IF_NULL(infer_type);
auto tensor_type = infer_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto data_type = tensor_type->element();
MS_EXCEPTION_IF_NULL(data_type);
return std::make_shared<abstract::AbstractTensor>(data_type, x_shape);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Ceil, prim::kPrimCeil, CeilInfer);
REGISTER_PRIMITIVE_C(kNameCeil, Ceil);
} // namespace ops
} // namespace mindspore

View File

@ -14,26 +14,29 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_CEIL_H_
#define MINDSPORE_CORE_C_OPS_CEIL_H_
#ifndef MINDSPORE_CORE_OPS_CEIL_H_
#define MINDSPORE_CORE_OPS_CEIL_H_
#include <vector>
#include <memory>
#include "c_ops/primitive_c.h"
#include "c_ops/op_utils.h"
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "abstract/abstract_value.h"
namespace mindspore {
namespace ops {
constexpr auto kNameCeil = "Ceil";
class Ceil : public PrimitiveC {
public:
Ceil() : PrimitiveC(kNameCeil) { InitIOName({"x"}, {"y"}); }
~Ceil() = default;
MS_DECLARE_PARENT(Ceil, PrimitiveC);
void init() {}
};
AbstractBasePtr CeilInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimCeil = std::shared_ptr<Ceil>;
using PrimCeilPtr = std::shared_ptr<Ceil>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_CEIL_H_
#endif // MINDSPORE_CORE_OPS_CEIL_H_

View File

@ -14,12 +14,13 @@
* limitations under the License.
*/
#include "c_ops/clip.h"
#include "ops/clip.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "c_ops/op_utils.h"
#include "ops/op_utils.h"
namespace mindspore {
namespace ops {
void Clip::Init(const float max, const float min) {
this->set_max(max);
this->set_min(min);
@ -39,4 +40,5 @@ float Clip::get_min() const {
return GetValue<float>(value_ptr);
}
REGISTER_PRIMITIVE_C(kNameClip, Clip);
} // namespace ops
} // namespace mindspore

View File

@ -13,15 +13,16 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_CLIP_H_
#define MINDSPORE_CORE_C_OPS_CLIP_H_
#ifndef MINDSPORE_CORE_OPS_CLIP_H_
#define MINDSPORE_CORE_OPS_CLIP_H_
#include <memory>
#include "c_ops/primitive_c.h"
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameClip = "Clip";
class Clip : public PrimitiveC {
public:
@ -36,6 +37,7 @@ class Clip : public PrimitiveC {
};
using PrimClipPtr = std::shared_ptr<Clip>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_CLIP_H_
#endif // MINDSPORE_CORE_OPS_CLIP_H_

View File

@ -0,0 +1,85 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <map>
#include <string>
#include "ops/concat.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
void Concat::Init(const int64_t axis) { this->set_axis(axis); }
int64_t Concat::get_axis() const {
auto value_ptr = this->GetAttr(kAxis);
return GetValue<int64_t>(value_ptr);
}
void Concat::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis)); }
AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim = primitive->cast<PrimConcatPtr>();
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto input_tuple = input_args[0]->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(input_tuple);
auto elements = input_tuple->elements();
CheckAndConvertUtils::CheckInteger("concat element num", elements.size(), kGreaterEqual, 1, prim_name);
auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(element0);
auto element0_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("element0 shape", element0->BuildShape(), prim_name);
auto element0_rank = SizeToLong(element0_shape.size());
auto axis = prim->get_axis();
CheckAndConvertUtils::CheckInRange<int64_t>("Concat axis", axis, kIncludeBoth, {-element0_rank - 1, element0_rank},
prim_name);
axis = axis < 0 ? axis + element0_rank : axis;
std::map<std::string, TypePtr> types;
types.emplace("element0", element0->BuildType());
int64_t all_shp = element0_shape[axis];
for (size_t i = 1; i < elements.size(); ++i) {
std::string elementi = "element" + std::to_string(i);
auto elementi_shape =
CheckAndConvertUtils::ConvertShapePtrToShape(elementi + " shape", elements[i]->BuildShape(), prim_name);
CheckAndConvertUtils::CheckInteger(elementi + " shape rank", elementi_shape.size(), kEqual, element0_shape.size(),
prim_name);
for (int64_t j = 0; j < element0_rank; ++j) {
if (j != axis && elementi_shape[j] != element0_shape[j]) {
MS_LOG(EXCEPTION) << "element " << i << " shape in input can not concat with first element.";
}
}
all_shp = all_shp == -1 || elementi_shape[axis] == -1 ? -1 : all_shp + elementi_shape[axis];
types.emplace(elementi, elements[i]->BuildType());
}
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, all_types, prim_name);
auto ret_shape = element0_shape;
ret_shape[axis] = all_shp;
return std::make_shared<abstract::AbstractTensor>(TypeIdToType(infer_type),
std::make_shared<abstract::Shape>(ret_shape));
}
REGISTER_PRIMITIVE_EVAL_IMPL(Concat, prim::kPrimConcat, ConcatInfer);
REGISTER_PRIMITIVE_C(kNameConcat, Concat);
} // namespace ops
} // namespace mindspore

View File

@ -14,29 +14,31 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_CONCAT_H_
#define MINDSPORE_CORE_C_OPS_CONCAT_H_
#ifndef MINDSPORE_CORE_OPS_CONCAT_H_
#define MINDSPORE_CORE_OPS_CONCAT_H_
#include <vector>
#include <memory>
#include "c_ops/primitive_c.h"
#include "c_ops/op_utils.h"
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameConcat = "Concat";
class Concat : public PrimitiveC {
public:
Concat() : PrimitiveC(kNameConcat) {}
~Concat() = default;
MS_DECLARE_PARENT(Concat, PrimitiveC);
void Init(int64_t axis = 0);
void set_axis(int64_t axis);
void Init(const int64_t axis = 0);
void set_axis(const int64_t axis);
int64_t get_axis() const;
};
AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimConcatPtr = std::shared_ptr<Concat>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_CONCAT_H_
#endif // MINDSPORE_CORE_OPS_CONCAT_H_

View File

@ -0,0 +1,58 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <map>
#include <string>
#include <memory>
#include "ops/constant.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "ops/op_utils.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto x = input_args[0]->BuildShape();
auto shape_element = x->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape_element);
return shape_element;
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
}
} // namespace
AbstractBasePtr ConstantInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(Constant, prim::kPrimConstant, ConstantInfer);
REGISTER_PRIMITIVE_C(kNameConstant, Constant);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_CONSTANT_H_
#define MINDSPORE_CORE_OPS_CONSTANT_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameConstant = "Constant";
class Constant : public PrimitiveC {
public:
Constant() : PrimitiveC(kNameConstant) {}
~Constant() = default;
MS_DECLARE_PARENT(Constant, PrimitiveC);
void Init() {}
};
AbstractBasePtr ConstantInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimConstantPtr = std::shared_ptr<Constant>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_CONSTANT_H_

View File

@ -14,12 +14,30 @@
* limitations under the License.
*/
#include "c_ops/constant_of_shape.h"
#include "c_ops/op_utils.h"
#include "ops/constant_of_shape.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
CheckAndConvertUtils::CheckInteger("input args size", input_args.size(), kEqual, 1, "ConstantOfShape");
auto input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "ConstantOfShape");
return std::make_shared<abstract::Shape>(input_shape);
}
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto constant_prim = primitive->cast<PrimConstantOfShapePtr>();
MS_EXCEPTION_IF_NULL(constant_prim);
auto data_type = TypeId(constant_prim->get_data_type());
return TypeIdToType(data_type);
}
} // namespace
void ConstantOfShape::Init(int64_t data_type, const std::vector<float> &value) {
this->set_data_type(data_type);
this->set_value(value);
@ -38,5 +56,12 @@ std::vector<float> ConstantOfShape::get_value() const {
auto value_ptr = this->GetAttr(kValue);
return GetValue<std::vector<float>>(value_ptr);
}
AbstractBasePtr ConstantOfShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(ConstantOfShape, prim::kPrimConstantOfShape, ConstantOfShapeInfer);
REGISTER_PRIMITIVE_C(kNameConstantOfShape, ConstantOfShape);
} // namespace ops
} // namespace mindspore

View File

@ -14,15 +14,16 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_CONSTANTOFSHAPE_H_
#define MINDSPORE_CORE_C_OPS_CONSTANTOFSHAPE_H_
#ifndef MINDSPORE_CORE_OPS_CONSTANT_OF_SHAPE_H_
#define MINDSPORE_CORE_OPS_CONSTANT_OF_SHAPE_H_
#include <memory>
#include <vector>
#include "c_ops/primitive_c.h"
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameConstantOfShape = "ConstantOfShape";
class ConstantOfShape : public PrimitiveC {
public:
@ -36,7 +37,10 @@ class ConstantOfShape : public PrimitiveC {
std::vector<float> get_value() const;
};
AbstractBasePtr ConstantOfShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimConstantOfShapePtr = std::shared_ptr<ConstantOfShape>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_CONSTANTOFSHAPE_H_
#endif // MINDSPORE_CORE_OPS_CONSTANT_OF_SHAPE_H_

View File

@ -14,19 +14,21 @@
* limitations under the License.
*/
#include "c_ops/control_depend.h"
#include "ops/control_depend.h"
namespace mindspore {
void ControlDepend::Init(int64_t depend_mode) { this->set_depend_mode(depend_mode); }
namespace ops {
void ControlDepend::Init(const int64_t depend_mode) { this->set_depend_mode(depend_mode); }
void ControlDepend::set_depend_mode(int64_t depend_mode) {
CheckAndConvertUtils::CheckInRange(kDependMode, depend_mode, kIncludeBoth, {0, 1}, name());
void ControlDepend::set_depend_mode(const int64_t depend_mode) {
CheckAndConvertUtils::CheckInRange<int64_t>(kDependMode, depend_mode, kIncludeBoth, {0, 1}, name());
AddAttr(kDependMode, MakeValue(depend_mode));
}
int64_t ControlDepend::get_depend_mode() {
int64_t ControlDepend::get_depend_mode() const {
auto value_ptr = GetAttr(kDependMode);
return GetValue<int64_t>(value_ptr);
}
REGISTER_PRIMITIVE_C(kNameControlDepend, ControlDepend);
} // namespace ops
} // namespace mindspore

View File

@ -14,29 +14,29 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_CONTROL_DEPEND_H_
#define MINDSPORE_CORE_C_OPS_CONTROL_DEPEND_H_
#ifndef MINDSPORE_CORE_OPS_CONTROL_DEPEND_H_
#define MINDSPORE_CORE_OPS_CONTROL_DEPEND_H_
#include <vector>
#include <memory>
#include "c_ops/primitive_c.h"
#include "c_ops/op_utils.h"
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameControlDepend = "ControlDepend";
class ControlDepend : public PrimitiveC {
public:
ControlDepend() : PrimitiveC(kNameControlDepend) {}
~ControlDepend() = default;
MS_DECLARE_PARENT(ControlDepend, PrimitiveC);
void Init(int64_t depend_mode);
void set_depend_mode(int64_t depend_mode);
int64_t get_depend_mode();
void Init(const int64_t depend_mode);
void set_depend_mode(const int64_t depend_mode = 0);
int64_t get_depend_mode() const;
};
AbstractBasePtr ControlDependInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimControlDepend = std::shared_ptr<ControlDepend>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_CONTROl_DEPEND_H_
#endif // MINDSPORE_CORE_OPS_CONTROl_DEPEND_H_

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "c_ops/conv2d.h"
#include "ops/conv2d.h"
#include <string>
#include <algorithm>
#include <memory>
@ -23,102 +23,29 @@
#include "ir/dtype/tensor_type.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "ops/control_depend.h"
namespace mindspore {
Conv2D::Conv2D() : PrimitiveC(kNameConv2D) { InitIOName({"x", "w"}, {"output"}); }
void Conv2D::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode,
const std::string &pad_mode, const std::vector<int64_t> &pad, const std::vector<int64_t> &stride,
const std::vector<int64_t> &dilation, int64_t group) {
auto prim_name = this->name();
this->AddAttr("data_format", MakeValue("NCHW"));
this->AddAttr("offset_a", MakeValue(static_cast<int64_t>(0)));
this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name));
this->set_stride(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name(), true, true));
this->set_dilation(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name(), true, true));
this->set_pad_mode(CheckAndConvertUtils::CheckString(kPadMode, pad_mode, {"valid", "same", "pad"}, prim_name));
CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, prim_name);
if (pad_mode == "pad") {
for (auto item : pad) {
CheckAndConvertUtils::Check("pad_item", item, kGreaterEqual, "zeros_list", 0, prim_name);
}
} else {
CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, prim_name);
}
this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name(), true, true));
this->set_mode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 1, prim_name));
this->set_out_channel(CheckAndConvertUtils::CheckInteger("out_channel", out_channel, kGreaterThan, 0, prim_name));
this->set_group(CheckAndConvertUtils::CheckInteger("group", group, kGreaterThan, 0, prim_name));
}
std::vector<int64_t> Conv2D::get_kernel_size() const {
auto value_ptr = GetAttr(kKernelSize);
return GetValue<std::vector<int64_t>>(value_ptr);
}
std::vector<int64_t> Conv2D::get_stride() const {
auto value_ptr = GetAttr(kStride);
return GetValue<std::vector<int64_t>>(value_ptr);
}
std::vector<int64_t> Conv2D::get_dilation() const {
auto value_ptr = GetAttr(kDilation);
return GetValue<std::vector<int64_t>>(value_ptr);
}
std::string Conv2D::get_pad_mode() const {
auto value_ptr = this->GetAttr(kPadMode);
return GetValue<string>(value_ptr);
}
std::vector<int64_t> Conv2D::get_pad() const {
auto value_ptr = this->GetAttr(kPad);
return GetValue<std::vector<int64_t>>(value_ptr);
}
std::vector<int64_t> Conv2D::get_pad_list() const {
auto value_ptr = this->GetAttr(kPadList);
return GetValue<std::vector<int64_t>>(value_ptr);
}
int64_t Conv2D::get_mode() const {
auto value_ptr = this->GetAttr(kMode);
return GetValue<int64_t>(value_ptr);
}
int64_t Conv2D::get_group() const {
auto value_ptr = this->GetAttr(kGroup);
return GetValue<int64_t>(value_ptr);
}
int64_t Conv2D::get_output_channel() const {
auto value_ptr = this->GetAttr(kOutputChannel);
return GetValue<int64_t>(value_ptr);
}
void Conv2D::set_kernel_size(const std::vector<int64_t> &kernel_size) {
this->AddAttr(kKernelSize, MakeValue(kernel_size));
}
void Conv2D::set_stride(const std::vector<int64_t> &stride) { this->AddAttr(kStride, MakeValue(stride)); }
void Conv2D::set_dilation(const std::vector<int64_t> &dilation) { this->AddAttr(kDilation, MakeValue(dilation)); }
void Conv2D::set_pad_mode(const std::string &pad_mode) { this->AddAttr(kPadMode, MakeValue(pad_mode)); }
void Conv2D::set_pad(const std::vector<int64_t> &pad) { this->AddAttr(kPad, MakeValue(pad)); }
void Conv2D::set_mode(int64_t mode) { this->AddAttr(kMode, MakeValue(mode)); }
void Conv2D::set_group(int64_t group) { this->AddAttr(kGroup, MakeValue(group)); }
void Conv2D::set_out_channel(int64_t output_channel) { this->AddAttr(kOutputChannel, MakeValue(output_channel)); }
void Conv2D::set_pad_list(const std::vector<int64_t> &pad_list) { this->AddAttr(kPadList, MakeValue(pad_list)); }
namespace ops {
namespace {
abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto conv_prim = primitive->cast<PrimConv2dPtr>();
MS_EXCEPTION_IF_NULL(conv_prim);
auto prim_name = conv_prim->name();
CheckAndConvertUtils::CheckInRange("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
CheckAndConvertUtils::CheckInRange<size_t>("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name);
if (conv_prim->get_format() == NHWC) {
x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};
w_shape = {w_shape[0], w_shape[3], w_shape[1], w_shape[2]};
}
CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name);
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->get_group(), kEqual, "w_shape[1]",
w_shape[1], conv_prim->name());
auto out_channel = conv_prim->get_output_channel();
auto out_channel = conv_prim->get_out_channel();
CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name());
std::vector<int64_t> temp_w;
std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w));
@ -137,10 +64,10 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
int64_t w_out = -1;
std::vector<int64_t> pad_list(4, 0);
auto pad_mode = conv_prim->get_pad_mode();
if (pad_mode == "valid") {
if (pad_mode == VALID) {
h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h);
w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w);
} else if (pad_mode == "same") {
} else if (pad_mode == SAME) {
h_out = ceil(x_shape[2] / stride_h);
w_out = ceil(x_shape[3] / stride_w);
@ -153,7 +80,7 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
auto pad_left = floor(pad_needed_w / 2);
pad_list.emplace_back(pad_left);
pad_list.emplace_back(pad_needed_h - pad_left);
} else if (pad_mode == "pad") {
} else if (pad_mode == PAD) {
std::copy(conv_prim->get_pad().begin(), conv_prim->get_pad().end(), std::back_inserter(pad_list));
auto pad_top = conv_prim->get_pad()[0];
auto pad_bottom = conv_prim->get_pad()[1];
@ -165,13 +92,17 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
h_out = floor(h_out);
w_out = floor(w_out);
}
conv_prim->set_pad_list(pad_list);
conv_prim->set_pad(pad_list);
std::vector<int64_t> out_shape = {x_shape[0], out_channel, h_out, w_out};
if (conv_prim->get_format() == NHWC) {
out_shape = {x_shape[0], h_out, w_out, out_channel};
}
return std::make_shared<abstract::Shape>(out_shape);
}
TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
CheckAndConvertUtils::CheckInRange("", input_args.size(), kIncludeBoth, {2, 3}, prim->name());
CheckAndConvertUtils::CheckInRange<size_t>("", input_args.size(), kIncludeBoth, {2, 3}, prim->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
@ -186,12 +117,121 @@ TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBase
}
return TypeIdToType(infer_type);
}
} // namespace
void Conv2D::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode, const PadMode &pad_mode,
const std::vector<int64_t> &pad, const std::vector<int64_t> &stride,
const std::vector<int64_t> &dilation, int64_t group, const Format &format) {
set_kernel_size(kernel_size);
set_stride(stride);
set_dilation(dilation);
set_pad(pad);
set_pad_mode(pad_mode);
set_mode(mode);
set_out_channel(out_channel);
set_group(group);
set_format(format);
}
void Conv2D::set_out_channel(int64_t out_channel) {
AddAttr(kOutChannel,
MakeValue(CheckAndConvertUtils::CheckInteger(kOutChannel, out_channel, kGreaterThan, 0, name())));
}
void Conv2D::set_kernel_size(const std::vector<int64_t> &kernel_size) {
AddAttr(kKernelSize, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, name())));
}
void Conv2D::set_stride(const std::vector<int64_t> &stride) {
AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name(), true, true)));
}
void Conv2D::set_dilation(const std::vector<int64_t> &dilation) {
AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name(), true, true)));
}
void Conv2D::set_pad_mode(const PadMode &pad_mode) {
std::vector<int64_t> pad = get_pad();
if (pad_mode == PAD) {
for (auto item : pad) {
CheckAndConvertUtils::Check(kPadItem, item, kGreaterEqual, "zeros_list", 0, name());
}
} else {
CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, name());
}
int64_t swi = pad_mode;
AddAttr(kPadMode, MakeValue(swi));
}
void Conv2D::set_pad(const std::vector<int64_t> &pad) {
CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name());
AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name(), true, true)));
}
void Conv2D::set_mode(int64_t mode) {
AddAttr(kMode, MakeValue(CheckAndConvertUtils::CheckInteger(kMode, mode, kEqual, 1, name())));
}
void Conv2D::set_group(int64_t group) {
AddAttr(kGroup, MakeValue(CheckAndConvertUtils::CheckInteger(kGroup, group, kGreaterThan, 0, name())));
}
void Conv2D::set_format(const Format &format) {
int64_t f = format;
AddAttr(kFormat, MakeValue(f));
}
int64_t Conv2D::get_out_channel() const {
auto value_ptr = GetAttr(kOutChannel);
return GetValue<int64_t>(value_ptr);
}
std::vector<int64_t> Conv2D::get_kernel_size() const {
auto value_ptr = GetAttr(kKernelSize);
return GetValue<std::vector<int64_t>>(value_ptr);
}
std::vector<int64_t> Conv2D::get_stride() const {
auto value_ptr = GetAttr(kStride);
return GetValue<std::vector<int64_t>>(value_ptr);
}
std::vector<int64_t> Conv2D::get_dilation() const {
auto value_ptr = GetAttr(kDilation);
return GetValue<std::vector<int64_t>>(value_ptr);
}
PadMode Conv2D::get_pad_mode() const {
auto value_ptr = GetAttr(kPadMode);
return PadMode(GetValue<int64_t>(value_ptr));
}
std::vector<int64_t> Conv2D::get_pad() const {
auto value_ptr = GetAttr(kPad);
return GetValue<std::vector<int64_t>>(value_ptr);
}
int64_t Conv2D::get_mode() const {
auto value_ptr = GetAttr(kMode);
return GetValue<int64_t>(value_ptr);
}
int64_t Conv2D::get_group() const {
auto value_ptr = GetAttr(kGroup);
return GetValue<int64_t>(value_ptr);
}
Format Conv2D::get_format() const {
auto value_ptr = GetAttr(kFormat);
return Format(GetValue<int64_t>(value_ptr));
}
AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(Conv2dInferType(primitive, input_args),
Conv2dInferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2D, prim::kPrimConv2D, Conv2dInfer);
REGISTER_PRIMITIVE_C(kNameConv2D, Conv2D);
} // namespace ops
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -14,49 +14,52 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_C_OPS_CONV2D_H_
#define MINDSPORE_CORE_C_OPS_CONV2D_H_
#ifndef MINDSPORE_CORE_OPS_CONV2D_H_
#define MINDSPORE_CORE_OPS_CONV2D_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "c_ops/op_utils.h"
#include "c_ops/primitive_c.h"
#include "ops/op_utils.h"
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameConv2D = "Conv2D";
class Conv2D : public PrimitiveC {
public:
Conv2D();
Conv2D() : PrimitiveC(kNameConv2D) { InitIOName({"x", "w"}, {"output"}); }
explicit Conv2D(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x", "w"}, {"output"}); }
~Conv2D() = default;
MS_DECLARE_PARENT(Conv2D, PrimitiveC);
void Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode = 1,
const std::string &pad_mode = "valid", const std::vector<int64_t> &pad = {0, 0, 0, 0},
const PadMode &pad_mode = VALID, const std::vector<int64_t> &pad = {0, 0, 0, 0},
const std::vector<int64_t> &stride = {1, 1, 1, 1}, const std::vector<int64_t> &dilation = {1, 1, 1, 1},
int64_t group = 1);
std::vector<int64_t> get_kernel_size() const;
std::vector<int64_t> get_stride() const;
std::vector<int64_t> get_dilation() const;
std::string get_pad_mode() const;
std::vector<int64_t> get_pad() const;
std::vector<int64_t> get_pad_list() const;
int64_t get_mode() const;
int64_t get_group() const;
int64_t get_output_channel() const;
int64_t group = 1, const Format &format = NCHW);
void set_kernel_size(const std::vector<int64_t> &kernel_size);
void set_stride(const std::vector<int64_t> &stride);
void set_dilation(const std::vector<int64_t> &dilation);
void set_pad_mode(const std::string &pad_mode);
void set_pad_mode(const PadMode &pad_mode);
void set_pad(const std::vector<int64_t> &pad);
void set_mode(int64_t mode);
void set_group(int64_t group);
void set_out_channel(int64_t output_channel);
void set_pad_list(const std::vector<int64_t> &pad_list);
void set_out_channel(int64_t out_channel);
void set_format(const Format &format);
std::vector<int64_t> get_kernel_size() const;
std::vector<int64_t> get_stride() const;
std::vector<int64_t> get_dilation() const;
PadMode get_pad_mode() const;
std::vector<int64_t> get_pad() const;
int64_t get_mode() const;
int64_t get_group() const;
int64_t get_out_channel() const;
Format get_format() const;
};
AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimConv2dPtr = std::shared_ptr<Conv2D>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_CONV2D_H_
#endif // MINDSPORE_CORE_OPS_CONV2D_H_

View File

@ -0,0 +1,199 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <map>
#include <vector>
#include <string>
#include <memory>
#include <set>
#include "ops/conv2d_transpose.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr Conv2dTransposeInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto conv2d_transpose_prim = primitive->cast<PrimConv2dTransposePtr>();
MS_EXCEPTION_IF_NULL(conv2d_transpose_prim);
auto prim_name = conv2d_transpose_prim->name();
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[3]->BuildShape(), prim_name);
return std::make_shared<abstract::Shape>(input_shape);
}
TypePtr Conv2dTransposeInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
CheckAndConvertUtils::CheckInteger("conv2d_transpose_infer", input_args.size(), kEqual, 3, prim->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32};
std::map<std::string, TypePtr> types;
types.emplace("doutput_dtye", input_args[0]->BuildType());
types.emplace("w_dtype", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
return TypeIdToType(infer_type);
}
} // namespace
void Conv2dTranspose::Init(int64_t in_channel, int64_t out_channel, const std::vector<int64_t> &kernel_size,
int64_t mode, const PadMode &pad_mode, const std::vector<int64_t> &pad,
const std::vector<int64_t> &stride, const std::vector<int64_t> &dilation, int64_t group,
const Format &format, const std::vector<int64_t> &pad_list) {
set_in_channel(in_channel);
set_out_channel(out_channel);
set_kernel_size(kernel_size);
set_mode(mode);
set_pad(pad);
set_pad_mode(pad_mode);
set_stride(stride);
set_dilation(dilation);
set_group(group);
set_format(format);
set_pad_list(pad_list);
}
void Conv2dTranspose::set_in_channel(int64_t in_channel) {
AddAttr(kOutChannel, MakeValue(CheckAndConvertUtils::CheckInteger(kInChannel, in_channel, kGreaterThan, 0, name())));
}
void Conv2dTranspose::set_out_channel(int64_t out_channel) {
AddAttr(kOutChannel,
MakeValue(CheckAndConvertUtils::CheckInteger(kOutChannel, out_channel, kGreaterThan, 0, name())));
}
void Conv2dTranspose::set_kernel_size(const std::vector<int64_t> &kernel_size) {
CheckAndConvertUtils::CheckInteger(kKernelSize, kernel_size.size(), kEqual, 2, name());
for (int64_t item : kernel_size) {
CheckAndConvertUtils::CheckInteger(kKernelSize, item, kGreaterEqual, 1, name());
}
AddAttr(kKernelSize, MakeValue(kernel_size));
}
void Conv2dTranspose::set_stride(const std::vector<int64_t> &stride) {
CheckAndConvertUtils::CheckInteger(kStride, stride.size(), kEqual, 2, name());
for (int64_t item : stride) {
CheckAndConvertUtils::CheckInteger(kStride, item, kGreaterEqual, 1, name());
}
AddAttr(kStride, MakeValue(stride));
}
void Conv2dTranspose::set_dilation(const std::vector<int64_t> &dilation) {
CheckAndConvertUtils::CheckInteger(kDilation, dilation.size(), kGreaterEqual, 2, name());
AddAttr(kDilation, MakeValue(dilation));
}
void Conv2dTranspose::set_pad_mode(const PadMode &pad_mode) {
std::vector<int64_t> pad = get_pad();
if (pad_mode == PAD) {
for (auto item : pad) {
CheckAndConvertUtils::Check(kPadItem, item, kGreaterEqual, "zeros_list", 0, name());
}
} else {
CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, name());
}
int64_t swi = pad_mode;
AddAttr(kPadMode, MakeValue(swi));
}
void Conv2dTranspose::set_pad(const std::vector<int64_t> &pad) {
CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name());
AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name(), true, true)));
}
void Conv2dTranspose::set_mode(int64_t mode) {
AddAttr(kMode, MakeValue(CheckAndConvertUtils::CheckInteger(kMode, mode, kEqual, 1, name())));
}
void Conv2dTranspose::set_group(int64_t group) {
AddAttr(kGroup, MakeValue(CheckAndConvertUtils::CheckInteger(kGroup, group, kGreaterThan, 0, name())));
}
void Conv2dTranspose::set_format(const Format &format) {
int64_t f = format;
AddAttr(kFormat, MakeValue(f));
}
void Conv2dTranspose::set_pad_list(const std::vector<int64_t> &pad_list) {
CheckAndConvertUtils::CheckInteger(kPadList, pad_list.size(), kEqual, 4, name());
this->AddAttr(kPadList, MakeValue(pad_list));
}
int64_t Conv2dTranspose::get_in_channel() const {
auto value_ptr = GetAttr(kInChannel);
return GetValue<int64_t>(value_ptr);
}
int64_t Conv2dTranspose::get_out_channel() const {
auto value_ptr = GetAttr(kOutChannel);
return GetValue<int64_t>(value_ptr);
}
std::vector<int64_t> Conv2dTranspose::get_kernel_size() const {
auto value_ptr = GetAttr(kKernelSize);
return GetValue<std::vector<int64_t>>(value_ptr);
}
std::vector<int64_t> Conv2dTranspose::get_stride() const {
auto value_ptr = GetAttr(kStride);
return GetValue<std::vector<int64_t>>(value_ptr);
}
std::vector<int64_t> Conv2dTranspose::get_dilation() const {
auto value_ptr = GetAttr(kDilation);
return GetValue<std::vector<int64_t>>(value_ptr);
}
PadMode Conv2dTranspose::get_pad_mode() const {
auto value_ptr = GetAttr(kPadMode);
return PadMode(GetValue<int64_t>(value_ptr));
}
std::vector<int64_t> Conv2dTranspose::get_pad() const {
auto value_ptr = GetAttr(kPad);
return GetValue<std::vector<int64_t>>(value_ptr);
}
int64_t Conv2dTranspose::get_mode() const {
auto value_ptr = GetAttr(kMode);
return GetValue<int64_t>(value_ptr);
}
int64_t Conv2dTranspose::get_group() const {
auto value_ptr = GetAttr(kGroup);
return GetValue<int64_t>(value_ptr);
}
Format Conv2dTranspose::get_format() const {
auto value_ptr = GetAttr(kFormat);
return Format(GetValue<int64_t>(value_ptr));
}
std::vector<int64_t> Conv2dTranspose::get_pad_list() const {
auto value_ptr = GetAttr(kPadList);
return GetValue<std::vector<int64_t>>(value_ptr);
}
AbstractBasePtr Conv2dTransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(Conv2dTransposeInferType(primitive, input_args),
Conv2dTransposeInferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2dTranspose, prim::kPrimConv2DTranspose, Conv2dTransposeInfer);
REGISTER_PRIMITIVE_C(kNameConv2dTranspose, Conv2dTranspose);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,74 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_CONV2D_TRANSPOSE_H_
#define MINDSPORE_CORE_OPS_CONV2D_TRANSPOSE_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/op_utils.h"
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameConv2dTranspose = "Conv2dTranspose";
class Conv2dTranspose : public PrimitiveC {
public:
Conv2dTranspose() : PrimitiveC(kNameConv2dTranspose) {
InitIOName({"out_backprop", "filter", "input_sizes"}, {"output"});
}
explicit Conv2dTranspose(const std::string k_name) : PrimitiveC(k_name) {
InitIOName({"out_backprop", "filter", "input_sizes"}, {"output"});
}
~Conv2dTranspose() = default;
MS_DECLARE_PARENT(Conv2dTranspose, PrimitiveC);
void Init(int64_t in_channel, int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode = 1,
const PadMode &pad_mode = VALID, const std::vector<int64_t> &pad = {0, 0, 0, 0},
const std::vector<int64_t> &stride = {1, 1}, const std::vector<int64_t> &dilation = {1, 1},
int64_t group = 1, const Format &format = NCHW, const std::vector<int64_t> &pad_list = {0, 0, 0, 0});
void set_in_channel(int64_t in_channel);
void set_out_channel(int64_t out_channel);
virtual void set_kernel_size(const std::vector<int64_t> &kernel_size);
void set_stride(const std::vector<int64_t> &stride);
virtual void set_dilation(const std::vector<int64_t> &dilation);
void set_pad_mode(const PadMode &pad_mode);
void set_pad(const std::vector<int64_t> &pad);
void set_mode(int64_t mode);
void set_group(int64_t group);
void set_format(const Format &format);
void set_pad_list(const std::vector<int64_t> &pad_list);
int64_t get_in_channel() const;
int64_t get_out_channel() const;
std::vector<int64_t> get_kernel_size() const;
std::vector<int64_t> get_stride() const;
std::vector<int64_t> get_dilation() const;
PadMode get_pad_mode() const;
std::vector<int64_t> get_pad() const;
int64_t get_mode() const;
int64_t get_group() const;
Format get_format() const;
std::vector<int64_t> get_pad_list() const;
};
AbstractBasePtr Conv2dTransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimConv2dTransposePtr = std::shared_ptr<Conv2dTranspose>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_CONV2D_TRANSPOSE_H_

Some files were not shown because too many files have changed in this diff Show More