forked from mindspore-Ecosystem/mindspore
add some ops
This commit is contained in:
parent
f9d9bba927
commit
d9be0c102d
|
@ -333,14 +333,13 @@ if(CMAKE_SYSTEM_NAME MATCHES "Windows")
|
|||
target_link_libraries(mindspore mindspore::pybind11_module)
|
||||
target_link_libraries(mindspore mindspore_gvar)
|
||||
target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore -Wl,--no-whole-archive)
|
||||
elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
||||
elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
||||
target_link_libraries(mindspore mindspore::pybind11_module)
|
||||
target_link_libraries(mindspore mindspore_gvar)
|
||||
target_link_libraries(_c_expression PRIVATE -Wl,-force_load mindspore -Wl,-noall_load)
|
||||
else()
|
||||
if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
|
||||
target_link_libraries(mindspore mindspore::pslite proto_input mindspore::protobuf
|
||||
mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a)
|
||||
else ()
|
||||
if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
|
||||
target_link_libraries(mindspore mindspore::pslite proto_input mindspore::protobuf mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a)
|
||||
target_link_libraries(mindspore -Wl,--no-as-needed mindspore::event_core ps_cache)
|
||||
if(${ENABLE_IBVERBS} STREQUAL "ON")
|
||||
target_link_libraries(mindspore ibverbs rdmacm)
|
||||
|
|
|
@ -717,7 +717,7 @@ std::unordered_set<PrimitivePtr> GetExpandOps() {
|
|||
prim::kPrimMinimumGrad,
|
||||
prim::kPrimGkDropout,
|
||||
prim::kPrimDropoutGrad,
|
||||
prim::kPrimSoftMax,
|
||||
prim::kPrimSoftmax,
|
||||
prim::kPrimLayerNorm,
|
||||
prim::kPrimLayerNormGrad,
|
||||
#endif
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "ir/manager.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
|
@ -1750,8 +1750,8 @@ void SessionBasic::RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<
|
|||
input_abstracts.emplace_back(abstract);
|
||||
}
|
||||
auto prim = AnfAlgo::GetCNodePrimitive(node);
|
||||
if (prim->isa<PrimitiveC>()) {
|
||||
auto prim_c = prim->cast<std::shared_ptr<PrimitiveC>>();
|
||||
if (prim->isa<ops::PrimitiveC>()) {
|
||||
auto prim_c = prim->cast<std::shared_ptr<ops::PrimitiveC>>();
|
||||
MS_EXCEPTION_IF_NULL(prim_c);
|
||||
auto abstract = prim_c->Infer(input_abstracts);
|
||||
node->set_abstract(abstract);
|
||||
|
|
|
@ -8,16 +8,16 @@ endif()
|
|||
message("************ build core ***************")
|
||||
|
||||
file(GLOB_RECURSE CORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
"abstract/*.cc"
|
||||
"base/*.cc"
|
||||
"c_ops/*.cc"
|
||||
"ir/*.cc"
|
||||
"utils/*.cc"
|
||||
"load_mindir/*.cc"
|
||||
)
|
||||
"abstract/*.cc"
|
||||
"base/*.cc"
|
||||
"ops/*.cc"
|
||||
"ir/*.cc"
|
||||
"utils/*.cc"
|
||||
"load_mindir/*.cc"
|
||||
)
|
||||
if(CMAKE_SYSTEM_NAME MATCHES "Windows")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-attributes -DHAVE_SNPRINTF")
|
||||
add_compile_definitions(BUILDING_DLL)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-attributes -DHAVE_SNPRINTF")
|
||||
add_compile_definitions(BUILDING_DLL)
|
||||
elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} \
|
||||
-Wuser-defined-warnings -Winconsistent-missing-override -Wno-delete-non-abstract-non-virtual-dtor")
|
||||
|
@ -28,5 +28,5 @@ add_library(mindspore_core STATIC ${CORE_SRC_LIST})
|
|||
target_link_libraries(mindspore_core PRIVATE mindspore_gvar)
|
||||
|
||||
if(USE_GLOG)
|
||||
target_link_libraries(mindspore_core PRIVATE mindspore::glog)
|
||||
target_link_libraries(mindspore_core PRIVATE mindspore::glog)
|
||||
endif()
|
||||
|
|
|
@ -91,7 +91,9 @@ inline const PrimitivePtr kPrimLabelSwitch = std::make_shared<Primitive>("LabelS
|
|||
inline const PrimitivePtr kPrimLabelSet = std::make_shared<Primitive>("LabelSet");
|
||||
|
||||
// Arrays
|
||||
inline const PrimitivePtr kPrimBroadcastTo = std::make_shared<Primitive>("BroadcastTo");
|
||||
inline const PrimitivePtr kPrimScalarToArray = std::make_shared<Primitive>("scalar_to_array");
|
||||
inline const PrimitivePtr kPrimTopK = std::make_shared<Primitive>("TopK");
|
||||
inline const PrimitivePtr kPrimArrayToScalar = std::make_shared<Primitive>("array_to_scalar");
|
||||
inline const PrimitivePtr kPrimBroadcastShape = std::make_shared<Primitive>("broadcast_shape");
|
||||
inline const PrimitivePtr kPrimArrayMap = std::make_shared<Primitive>("array_map");
|
||||
|
@ -99,17 +101,25 @@ inline const PrimitivePtr kPrimArrayReduce = std::make_shared<Primitive>("array_
|
|||
inline const PrimitivePtr kPrimCast = std::make_shared<Primitive>("Cast");
|
||||
inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat");
|
||||
inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze");
|
||||
inline const PrimitivePtr kPrimUnsqueeze = std::make_shared<Primitive>("Unsqueeze");
|
||||
inline const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose");
|
||||
inline const PrimitivePtr kPrimGatherV2 = std::make_shared<Primitive>("GatherV2");
|
||||
inline const PrimitivePtr kPrimGatherD = std::make_shared<Primitive>("GatherD");
|
||||
inline const PrimitivePtr kPrimGather = std::make_shared<Primitive>(kGather);
|
||||
inline const PrimitivePtr kPrimGather = std::make_shared<Primitive>("Gather");
|
||||
inline const PrimitivePtr kPrimGatherND = std::make_shared<Primitive>("GatherND");
|
||||
inline const PrimitivePtr kPrimSparseGatherV2 = std::make_shared<Primitive>("SparseGatherV2");
|
||||
inline const PrimitivePtr kPrimSparseToDense = std::make_shared<Primitive>("SparseToDense");
|
||||
inline const PrimitivePtr kPrimShape = std::make_shared<Primitive>("Shape");
|
||||
inline const PrimitivePtr kPrimStridedSlice = std::make_shared<Primitive>("StridedSlice");
|
||||
inline const PrimitivePtr kPrimDynamicShape = std::make_shared<Primitive>("DynamicShape");
|
||||
inline const PrimitivePtr kPrimEmbeddingLookup = std::make_shared<Primitive>("EmbeddingLookup");
|
||||
inline const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared<Primitive>("EmbeddingLookupCommGrad");
|
||||
inline const PrimitivePtr kPrimSize = std::make_shared<Primitive>("Size");
|
||||
inline const PrimitivePtr kPrimArgMax = std::make_shared<Primitive>("Argmax");
|
||||
inline const PrimitivePtr kPrimArgMin = std::make_shared<Primitive>("Argmin");
|
||||
inline const PrimitivePtr kPrimPack = std::make_shared<Primitive>("Pack");
|
||||
inline const PrimitivePtr kPrimUnpack = std::make_shared<Primitive>("Unpack");
|
||||
inline const PrimitivePtr kPrimUnstack = std::make_shared<Primitive>("Unstack");
|
||||
inline const PrimitivePtr kPrimUnsortedSegmentMax = std::make_shared<Primitive>("UnsortedSegmentMax");
|
||||
inline const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>("UnsortedSegmentSum");
|
||||
inline const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin");
|
||||
|
@ -123,6 +133,7 @@ inline const PrimitivePtr kPrimCacheSwapTable = std::make_shared<Primitive>("Cac
|
|||
inline const PrimitivePtr kPrimDynamicAssign = std::make_shared<Primitive>("DynamicAssign");
|
||||
inline const PrimitivePtr kPrimPadAndShift = std::make_shared<Primitive>("PadAndShift");
|
||||
inline const PrimitivePtr kPrimSlice = std::make_shared<Primitive>("Slice");
|
||||
inline const PrimitivePtr kPrimSliceFusion = std::make_shared<Primitive>("SliceFusion");
|
||||
inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile");
|
||||
inline const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN");
|
||||
inline const PrimitivePtr kPrimAccumulateNV2 = std::make_shared<Primitive>("AccumulateNV2");
|
||||
|
@ -145,16 +156,36 @@ inline const PrimitivePtr kPrimSequenceMask = std::make_shared<Primitive>("Seque
|
|||
inline const PrimitivePtr kPrimRange = std::make_shared<Primitive>("Range");
|
||||
inline const PrimitivePtr kPrimSpaceToBatchND = std::make_shared<Primitive>("SpaceToBatchND");
|
||||
inline const PrimitivePtr kPrimBatchToSpaceND = std::make_shared<Primitive>("BatchToSpaceND");
|
||||
inline const PrimitivePtr kPrimDepthToSpace = std::make_shared<Primitive>("DepthToSpace");
|
||||
inline const PrimitivePtr kPrimBatchToSpace = std::make_shared<Primitive>("BatchToSpace");
|
||||
inline const PrimitivePtr kPrimSpaceToBatch = std::make_shared<Primitive>("SpaceToBatch");
|
||||
inline const PrimitivePtr kPrimScatterNd = std::make_shared<Primitive>("ScatterNd");
|
||||
inline const PrimitivePtr kPrimConstantOfShape = std::make_shared<Primitive>("ConstantOfShape");
|
||||
inline const PrimitivePtr kPrimSquaredDifference = std::make_shared<Primitive>("SquaredDifference");
|
||||
inline const PrimitivePtr kPrimReverseV2 = std::make_shared<Primitive>("ReverseV2");
|
||||
inline const PrimitivePtr kPrimReverseSequence = std::make_shared<Primitive>("ReverseSequence");
|
||||
inline const PrimitivePtr kPrimRank = std::make_shared<Primitive>("Rank");
|
||||
inline const PrimitivePtr kPrimResizeBilinear = std::make_shared<Primitive>("ResizeBilinear");
|
||||
|
||||
// NN
|
||||
inline const PrimitivePtr kPrimAdam = std::make_shared<Primitive>("Adam");
|
||||
inline const PrimitivePtr kPrimAudioSpectrogram = std::make_shared<Primitive>("AudioSpectrogram");
|
||||
inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
|
||||
inline const PrimitivePtr kPrimSoftMax = std::make_shared<Primitive>("Softmax");
|
||||
inline const PrimitivePtr kPrimCrop = std::make_shared<Primitive>("Crop");
|
||||
inline const PrimitivePtr kPrimFlattenGrad = std::make_shared<Primitive>("FlattenGrad");
|
||||
inline const PrimitivePtr kPrimSoftmax = std::make_shared<Primitive>("Softmax");
|
||||
inline const PrimitivePtr kPrimSparseSoftmaxCrossEntropy = std::make_shared<Primitive>("SparseSoftmaxCrossEntropy");
|
||||
inline const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax");
|
||||
inline const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad");
|
||||
inline const PrimitivePtr kPrimLstm = std::make_shared<Primitive>("Lstm");
|
||||
inline const PrimitivePtr kPrimTan = std::make_shared<Primitive>("Tan");
|
||||
inline const PrimitivePtr kPrimAtan = std::make_shared<Primitive>("Atan");
|
||||
inline const PrimitivePtr kPrimAsin = std::make_shared<Primitive>("Asin");
|
||||
inline const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh");
|
||||
inline const PrimitivePtr kPrimTanhGrad = std::make_shared<Primitive>("TanhGrad");
|
||||
inline const PrimitivePtr kPrimPooling = std::make_shared<Primitive>("Pooling");
|
||||
inline const PrimitivePtr kPrimPoolingGrad = std::make_shared<Primitive>("PoolingGrad");
|
||||
inline const PrimitivePtr kPrimROIPooling = std::make_shared<Primitive>("ROIPooling");
|
||||
inline const PrimitivePtr kPrimMaxPool = std::make_shared<Primitive>("MaxPool");
|
||||
inline const PrimitivePtr kPrimMaxPoolGrad = std::make_shared<Primitive>("MaxPoolGrad");
|
||||
inline const PrimitivePtr kPrimMaxPoolWithArgmax = std::make_shared<Primitive>("MaxPoolWithArgmax");
|
||||
|
@ -168,6 +199,9 @@ inline const PrimitivePtr kPrimFusedSparseAdam = std::make_shared<Primitive>("Fu
|
|||
inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm");
|
||||
inline const PrimitivePtr kPrimFusedBatchNormEx = std::make_shared<Primitive>("FusedBatchNormEx");
|
||||
inline const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D");
|
||||
inline const PrimitivePtr kPrimFullConnection = std::make_shared<Primitive>("FullConnection");
|
||||
inline const PrimitivePtr kPrimConv2DTranspose = std::make_shared<Primitive>("Conv2DTranspose");
|
||||
inline const PrimitivePtr kPrimGroupConv2DGradInput = std::make_shared<Primitive>("GroupConv2DGradInput");
|
||||
inline const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>("FusedBatchNormGrad");
|
||||
inline const PrimitivePtr kPrimFusedBatchNormGradEx = std::make_shared<Primitive>("FusedBatchNormGradEx");
|
||||
inline const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm");
|
||||
|
@ -179,21 +213,34 @@ inline const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared<Primitive>
|
|||
inline const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive>("Conv2DBackpropFilter");
|
||||
inline const PrimitivePtr kPrimConv3DBackpropInput = std::make_shared<Primitive>("Conv3DBackpropInput");
|
||||
inline const PrimitivePtr kPrimConv3DBackpropFilter = std::make_shared<Primitive>("Conv3DBackpropFilter");
|
||||
inline const PrimitivePtr kPrimCustomNormalize = std::make_shared<Primitive>("CustomNormalize");
|
||||
inline const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared<Primitive>("DepthwiseConv2dNative");
|
||||
inline const PrimitivePtr kPrimCTCGreedyDecoder = std::make_shared<Primitive>("CTCGreedyDecoder");
|
||||
inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter =
|
||||
std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropFilter");
|
||||
inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput =
|
||||
std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropInput");
|
||||
inline const PrimitivePtr kPrimDetectionPostProcess = std::make_shared<Primitive>("DetectionPostProcess");
|
||||
inline const PrimitivePtr kPrimBiasAdd = std::make_shared<Primitive>("BiasAdd");
|
||||
inline const PrimitivePtr kPrimBiasGrad = std::make_shared<Primitive>("BiasGrad");
|
||||
inline const PrimitivePtr kPrimBiasAddGrad = std::make_shared<Primitive>("BiasAddGrad");
|
||||
inline const PrimitivePtr kPrimBiasSubGrad = std::make_shared<Primitive>("BiasSubGrad");
|
||||
inline const PrimitivePtr kPrimBinaryCrossEntropy = std::make_shared<Primitive>("BinaryCrossEntropy");
|
||||
inline const PrimitivePtr kPrimBinaryCrossEntropyGrad = std::make_shared<Primitive>("BinaryCrossEntropyGrad");
|
||||
inline const PrimitivePtr kPrimSmoothL1Loss = std::make_shared<Primitive>("SmoothL1Loss");
|
||||
inline const PrimitivePtr kPrimSmoothL1LossGrad = std::make_shared<Primitive>("SmoothL1LossGrad");
|
||||
inline const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits =
|
||||
std::make_shared<Primitive>("SoftmaxCrossEntropyWithLogits");
|
||||
inline const PrimitivePtr kPrimSigmoidCrossEntropyWithLogits =
|
||||
std::make_shared<Primitive>("SigmoidCrossEntropyWithLogits");
|
||||
inline const PrimitivePtr kPrimSigmoidCrossEntropyWithLogitsGrad =
|
||||
std::make_shared<Primitive>("SigmoidCrossEntropyWithLogitsGrad");
|
||||
inline const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits =
|
||||
std::make_shared<Primitive>("SparseSoftmaxCrossEntropyWithLogits");
|
||||
inline const PrimitivePtr kPrimMomentum = std::make_shared<Primitive>("Momentum");
|
||||
inline const PrimitivePtr kPrimApplyMomentum = std::make_shared<Primitive>("ApplyMomentum");
|
||||
inline const PrimitivePtr kPrimLayerNorm = std::make_shared<Primitive>("LayerNorm");
|
||||
inline const PrimitivePtr kPrimLrn = std::make_shared<Primitive>("Lrn");
|
||||
inline const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>("LayerNormGrad");
|
||||
inline const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared<Primitive>("LayerNormXBackprop");
|
||||
inline const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Primitive>("LayerNormBetaGammaBackprop");
|
||||
|
@ -204,18 +251,22 @@ inline const PrimitivePtr kPrimDropout = std::make_shared<Primitive>("Dropout");
|
|||
inline const PrimitivePtr kPrimUniformReal = std::make_shared<Primitive>("UniformReal");
|
||||
inline const PrimitivePtr kPrimCudnnUniformReal = std::make_shared<Primitive>("CudnnUniformReal");
|
||||
inline const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot");
|
||||
inline const PrimitivePtr kPrimGeLU = std::make_shared<Primitive>("Gelu");
|
||||
inline const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu");
|
||||
inline const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad");
|
||||
inline const PrimitivePtr kPrimFastGelu = std::make_shared<Primitive>("FastGelu");
|
||||
inline const PrimitivePtr kPrimFastGeluGrad = std::make_shared<Primitive>("FastGeluGrad");
|
||||
inline const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU");
|
||||
inline const PrimitivePtr kPrimElu = std::make_shared<Primitive>("Elu");
|
||||
inline const PrimitivePtr kPrimRelu6 = std::make_shared<Primitive>("ReLU6");
|
||||
inline const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
|
||||
inline const PrimitivePtr kPrimPRelu = std::make_shared<Primitive>("PReLU");
|
||||
inline const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike");
|
||||
inline const PrimitivePtr kPrimOnesLike = std::make_shared<Primitive>("OnesLike");
|
||||
inline const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");
|
||||
inline const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer");
|
||||
inline const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel");
|
||||
inline const PrimitivePtr kPrimFakeQuantWithMinMaxVars = std::make_shared<Primitive>("FakeQuantWithMinMaxVars");
|
||||
inline const PrimitivePtr kPrimApplyRMSProp = std::make_shared<Primitive>("ApplyRMSProp");
|
||||
inline const PrimitivePtr kPrimSparseApplyFtrl = std::make_shared<Primitive>("SparseApplyFtrl");
|
||||
inline const PrimitivePtr kPrimSparseApplyProximalAdagrad = std::make_shared<Primitive>("SparseApplyProximalAdagrad");
|
||||
|
@ -224,6 +275,8 @@ inline const PrimitivePtr kPrimFusedAdamWeightDecay = std::make_shared<Primitive
|
|||
inline const PrimitivePtr kPrimSGD = std::make_shared<Primitive>("SGD");
|
||||
inline const PrimitivePtr kPrimClipByNormNoDivSum = std::make_shared<Primitive>("ClipByNormNoDivSum");
|
||||
inline const PrimitivePtr kPrimTensorMove = std::make_shared<Primitive>("TensorMove");
|
||||
inline const PrimitivePtr kPrimL2Normalize = std::make_shared<Primitive>("L2Normalize");
|
||||
inline const PrimitivePtr kPrimCustomExtractFeatures = std::make_shared<Primitive>("CustomExtractFeatures");
|
||||
|
||||
// Comm ops
|
||||
inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
|
||||
|
@ -239,6 +292,12 @@ inline const PrimitivePtr kPrimAllGather = std::make_shared<Primitive>("AllGathe
|
|||
inline const PrimitivePtr kPrimReduceScatter = std::make_shared<Primitive>("ReduceScatter");
|
||||
inline const PrimitivePtr kPrimMemCpyAsync = std::make_shared<Primitive>("memcpy_async");
|
||||
inline const PrimitivePtr kPrimFill = std::make_shared<Primitive>("Fill");
|
||||
// Quant ops
|
||||
inline const PrimitivePtr kPrimBatchNormFold = std::make_shared<Primitive>("BatchNormFold");
|
||||
inline const PrimitivePtr kPrimFakeQuantWithMinMaxVarsPerChannel =
|
||||
std::make_shared<Primitive>("FakeQuantWithMinMaxVarsPerChannel");
|
||||
// Control ops
|
||||
inline const PrimitivePtr kPrimMerge = std::make_shared<Primitive>("Merge");
|
||||
// RowTensor
|
||||
inline const PrimitivePtr kPrimMakeRowTensor = std::make_shared<Primitive>("MakeRowTensor");
|
||||
inline const PrimitivePtr kPrimRowTensorGetValues = std::make_shared<Primitive>("RowTensorGetValues");
|
||||
|
@ -251,12 +310,22 @@ inline const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared<Primitiv
|
|||
inline const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primitive>("SparseTensorGetIndices");
|
||||
inline const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape");
|
||||
|
||||
// TensorList
|
||||
inline const PrimitivePtr kPrimTensorListFromTensor = std::make_shared<Primitive>("TensorListFromTensor");
|
||||
inline const PrimitivePtr kPrimTensorListReserve = std::make_shared<Primitive>("TensorListReserve");
|
||||
inline const PrimitivePtr kPrimTensorListStack = std::make_shared<Primitive>("TensorListStack");
|
||||
inline const PrimitivePtr kPrimTensorListSetItem = std::make_shared<Primitive>("TensorListSetItem");
|
||||
|
||||
// Maths
|
||||
inline const PrimitivePtr kPrimCeil = std::make_shared<Primitive>("Ceil");
|
||||
inline const PrimitivePtr kPrimTensorAdd = std::make_shared<Primitive>("TensorAdd");
|
||||
inline const PrimitivePtr kPrimAdd = std::make_shared<Primitive>("Add");
|
||||
inline const PrimitivePtr kPrimMatMul = std::make_shared<Primitive>("MatMul");
|
||||
inline const PrimitivePtr kPrimMatrixDiag = std::make_shared<Primitive>("MatrixDiag");
|
||||
inline const PrimitivePtr kPrimBatchMatMul = std::make_shared<Primitive>("BatchMatMul");
|
||||
inline const PrimitivePtr kPrimMaximumGrad = std::make_shared<Primitive>("MaximumGrad");
|
||||
inline const PrimitivePtr kPrimMinimumGrad = std::make_shared<Primitive>("MinimumGrad");
|
||||
inline const PrimitivePtr kPrimReduce = std::make_shared<Primitive>("Reduce");
|
||||
inline const PrimitivePtr kPrimReduceMean = std::make_shared<Primitive>("ReduceMean");
|
||||
inline const PrimitivePtr kPrimReduceSum = std::make_shared<Primitive>("ReduceSum");
|
||||
inline const PrimitivePtr kPrimReduceAll = std::make_shared<Primitive>("ReduceAll");
|
||||
|
@ -264,6 +333,8 @@ inline const PrimitivePtr kPrimReduceAny = std::make_shared<Primitive>("ReduceAn
|
|||
inline const PrimitivePtr kPrimReduceMax = std::make_shared<Primitive>("ReduceMax");
|
||||
inline const PrimitivePtr kPrimReduceMin = std::make_shared<Primitive>("ReduceMin");
|
||||
inline const PrimitivePtr kPrimNeg = std::make_shared<Primitive>("Neg");
|
||||
inline const PrimitivePtr kPrimSin = std::make_shared<Primitive>("Sin");
|
||||
inline const PrimitivePtr kPrimCos = std::make_shared<Primitive>("Cos");
|
||||
inline const PrimitivePtr kPrimSub = std::make_shared<Primitive>("Sub");
|
||||
inline const PrimitivePtr kPrimMul = std::make_shared<Primitive>("Mul");
|
||||
inline const PrimitivePtr kPrimDiv = std::make_shared<Primitive>("Div");
|
||||
|
@ -279,6 +350,7 @@ inline const PrimitivePtr kPrimSubscalar = std::make_shared<Primitive>("Subscala
|
|||
inline const PrimitivePtr kPrimInplaceAdd = std::make_shared<Primitive>("InplaceAdd");
|
||||
inline const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub");
|
||||
inline const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow");
|
||||
inline const PrimitivePtr kPrimPower = std::make_shared<Primitive>("Power");
|
||||
inline const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>("RealDiv");
|
||||
inline const PrimitivePtr kPrimFloorDiv = std::make_shared<Primitive>("FloorDiv");
|
||||
inline const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt");
|
||||
|
@ -292,12 +364,13 @@ inline const PrimitivePtr kPrimLog = std::make_shared<Primitive>("Log");
|
|||
inline const PrimitivePtr kPrimRsqrt = std::make_shared<Primitive>("Rsqrt");
|
||||
inline const PrimitivePtr kPrimSplitV = std::make_shared<Primitive>("SplitV");
|
||||
inline const PrimitivePtr kPrimLinSpace = std::make_shared<Primitive>("LinSpace");
|
||||
inline const PrimitivePtr kPrimNonMaxSuppression = std::make_shared<Primitive>("NonMaxSuppression");
|
||||
inline const PrimitivePtr kPrimSign = std::make_shared<Primitive>("Sign");
|
||||
inline const PrimitivePtr kPrimSquaredDifference = std::make_shared<Primitive>("SquaredDifference");
|
||||
inline const PrimitivePtr kPrimAsin = std::make_shared<Primitive>("Asin");
|
||||
inline const PrimitivePtr kPrimACos = std::make_shared<Primitive>("ACos");
|
||||
inline const PrimitivePtr kPrimAsinGrad = std::make_shared<Primitive>("AsinGrad");
|
||||
inline const PrimitivePtr kPrimACosGrad = std::make_shared<Primitive>("ACosGrad");
|
||||
inline const PrimitivePtr kPrimFloorMod = std::make_shared<Primitive>("FloorMod");
|
||||
inline const PrimitivePtr kPrimWhere = std::make_shared<Primitive>("Where");
|
||||
|
||||
// Statements
|
||||
inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return");
|
||||
|
@ -323,6 +396,7 @@ inline const PrimitivePtr kPrimGenerateShapeIndex = std::make_shared<Primitive>(
|
|||
inline const PrimitivePtr kPrimGenerateInverseIndex = std::make_shared<Primitive>("generate_inverse_index");
|
||||
|
||||
// Debug ops
|
||||
inline const PrimitivePtr kPrimAssert = std::make_shared<Primitive>("Assert");
|
||||
inline const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary");
|
||||
inline const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary");
|
||||
inline const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary");
|
||||
|
@ -349,6 +423,13 @@ inline const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict");
|
|||
inline const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
|
||||
inline const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant");
|
||||
inline const PrimitivePtr kPrimEquivFormat = std::make_shared<Primitive>("EquivFormat");
|
||||
inline const PrimitivePtr kPrimLshProjection = std::make_shared<Primitive>("LshProjection");
|
||||
inline const PrimitivePtr kPrimHashtableLookup = std::make_shared<Primitive>("HashtableLookup");
|
||||
inline const PrimitivePtr kPrimCustomPredict = std::make_shared<Primitive>("CustomPredict");
|
||||
inline const PrimitivePtr kPrimStack = std::make_shared<Primitive>("Stack");
|
||||
inline const PrimitivePtr kPrimPriorBox = std::make_shared<Primitive>("PriorBox");
|
||||
inline const PrimitivePtr kPrimQuantDTypeCast = std::make_shared<Primitive>("QuantDTypeCast");
|
||||
inline const PrimitivePtr kPrimWhile = std::make_shared<Primitive>("While");
|
||||
|
||||
// Structures
|
||||
inline const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list");
|
||||
|
@ -371,7 +452,7 @@ inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_
|
|||
inline const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");
|
||||
inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
|
||||
|
||||
// Other primitive not used by backend but used in core;
|
||||
// Other primitve not used by backend but used in core;
|
||||
inline const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem");
|
||||
inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J");
|
||||
|
||||
|
@ -382,6 +463,44 @@ inline const PrimitivePtr kPrimMakeDict = std::make_shared<Primitive>("make_dict
|
|||
// GraphKernel ops
|
||||
inline const PrimitivePtr kPrimInplaceAssign = std::make_shared<Primitive>("InplaceAssign");
|
||||
|
||||
// Only used in lite
|
||||
inline const PrimitivePtr kPrimLeakyRelu = std::make_shared<Primitive>("LeakyRelu");
|
||||
inline const PrimitivePtr kPrimConstant = std::make_shared<Primitive>("Constant");
|
||||
inline const PrimitivePtr kPrimLocalResponseNormalization = std::make_shared<Primitive>("LocalResponseNormalization");
|
||||
inline const PrimitivePtr kPrimFftReal = std::make_shared<Primitive>("FftReal");
|
||||
inline const PrimitivePtr kPrimMfcc = std::make_shared<Primitive>("Mfcc");
|
||||
inline const PrimitivePtr kPrimRfft = std::make_shared<Primitive>("Rfft");
|
||||
inline const PrimitivePtr kPrimFftImag = std::make_shared<Primitive>("FftImag");
|
||||
inline const PrimitivePtr kPrimSkipGram = std::make_shared<Primitive>("SkipGram");
|
||||
inline const PrimitivePtr kPrimConv2DFusion = std::make_shared<Primitive>("Conv2DFusion");
|
||||
inline const PrimitivePtr kPrimConv2dTransposeFusion = std::make_shared<Primitive>("Conv2dTransposeFusion");
|
||||
inline const PrimitivePtr kPrimDepthWiseConv2DFusion = std::make_shared<Primitive>("DepthWiseConv2DFusion");
|
||||
inline const PrimitivePtr kPrimAddFusion = std::make_shared<Primitive>("AddFusion");
|
||||
inline const PrimitivePtr kPrimScaleFusion = std::make_shared<Primitive>("ScaleFusion");
|
||||
inline const PrimitivePtr kPrimSubFusion = std::make_shared<Primitive>("SubFusion");
|
||||
inline const PrimitivePtr kPrimMulFusion = std::make_shared<Primitive>("MulFusion");
|
||||
inline const PrimitivePtr kPrimSigmoid = std::make_shared<Primitive>("Sigmoid");
|
||||
inline const PrimitivePtr kPrimClip = std::make_shared<Primitive>("Clip");
|
||||
inline const PrimitivePtr kPrimHardTanh = std::make_shared<Primitive>("HardTanh");
|
||||
inline const PrimitivePtr kPrimDepthWiseConv2DTransposeFusion =
|
||||
std::make_shared<Primitive>("DepthWiseConv2DTransposeFusion");
|
||||
inline const PrimitivePtr kPrimArgMinFusion = std::make_shared<Primitive>("ArgMinFusion");
|
||||
inline const PrimitivePtr kPrimArgMaxFusion = std::make_shared<Primitive>("ArgMaxFusion");
|
||||
inline const PrimitivePtr kPrimSpaceToDepth = std::make_shared<Primitive>("SpaceToDepth");
|
||||
inline const PrimitivePtr kPrimPadFusion = std::make_shared<Primitive>("PadFusion");
|
||||
inline const PrimitivePtr kPrimPowFusion = std::make_shared<Primitive>("PowFusion");
|
||||
inline const PrimitivePtr kPrimResize = std::make_shared<Primitive>("Resize");
|
||||
inline const PrimitivePtr kPrimConv2dTranspose = std::make_shared<Primitive>("Conv2dTranspose");
|
||||
inline const PrimitivePtr kPrimArgMinWithValue = std::make_shared<Primitive>("ArgMinWithValue");
|
||||
inline const PrimitivePtr kPrimIf = std::make_shared<Primitive>("If");
|
||||
inline const PrimitivePtr kPrimAvgPoolFusion = std::make_shared<Primitive>("AvgPoolFusion");
|
||||
inline const PrimitivePtr kPrimMaxPoolFusion = std::make_shared<Primitive>("MaxPoolFusion");
|
||||
inline const PrimitivePtr kPrimActivation = std::make_shared<Primitive>("Activation");
|
||||
inline const PrimitivePtr kPrimTopKFusion = std::make_shared<Primitive>("TopKFusion");
|
||||
inline const PrimitivePtr kPrimTileFusion = std::make_shared<Primitive>("TileFusion");
|
||||
inline const PrimitivePtr kPrimReduceFusion = std::make_shared<Primitive>("ReduceFusion");
|
||||
inline const PrimitivePtr kPrimLayerNormFusion = std::make_shared<Primitive>("LayerNormFusion");
|
||||
|
||||
class DoSignaturePrimitive : public Primitive {
|
||||
public:
|
||||
explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function)
|
||||
|
|
|
@ -1,19 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "c_ops/abs.h"
|
||||
namespace mindspore {
|
||||
REGISTER_PRIMITIVE_C(kNameAbs, Abs);
|
||||
} // namespace mindspore
|
|
@ -1,51 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/apply_momentum.h"
|
||||
#include "c_ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
void ApplyMomentum::Init(bool use_nesterov, bool use_locking, float gradient_scale) {
|
||||
this->set_use_nesterov(use_nesterov);
|
||||
this->set_use_locking(use_locking);
|
||||
this->set_gradient_scale(gradient_scale);
|
||||
}
|
||||
|
||||
void ApplyMomentum::set_use_nesterov(bool use_nesterov) { this->AddAttr(kUseNesterov, MakeValue(use_nesterov)); }
|
||||
|
||||
void ApplyMomentum::set_use_locking(bool use_locking) { this->AddAttr(kUseLocking, MakeValue(use_locking)); }
|
||||
|
||||
void ApplyMomentum::set_gradient_scale(float gradient_scale) {
|
||||
this->AddAttr(kGradientScale, MakeValue(gradient_scale));
|
||||
}
|
||||
|
||||
bool ApplyMomentum::get_use_nesterov() const {
|
||||
auto value_ptr = GetAttr(kUseNesterov);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
bool ApplyMomentum::get_use_locking() const {
|
||||
auto value_ptr = GetAttr(kUseLocking);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
float ApplyMomentum::get_gradient_scale() {
|
||||
auto value_ptr = GetAttr(kGradientScale);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameApplyMomentum, ApplyMomentum);
|
||||
} // namespace mindspore
|
|
@ -1,54 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/audio_spectrogram.h"
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "c_ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
||||
void AudioSpectrogram::set_window_size(const int64_t &window_size) {
|
||||
this->AddAttr(kWindowSize, MakeValue(window_size));
|
||||
}
|
||||
int64_t AudioSpectrogram::get_window_size() const {
|
||||
auto value_ptr = GetAttr(kWindowSize);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
void AudioSpectrogram::set_stride(const int64_t &stride) { this->AddAttr(kStride, MakeValue(stride)); }
|
||||
int64_t AudioSpectrogram::get_stride() const {
|
||||
auto value_ptr = GetAttr(kStride);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
void AudioSpectrogram::set_mag_square(const bool &mag_square) { this->AddAttr(kMagSquare, MakeValue(mag_square)); }
|
||||
bool AudioSpectrogram::get_mag_square() const {
|
||||
auto value_ptr = GetAttr(kMagSquare);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
void AudioSpectrogram::Init(const int64_t &window_size, const int64_t &stride, const bool &mag_square) {
|
||||
this->set_window_size(window_size);
|
||||
this->set_stride(stride);
|
||||
this->set_mag_square(mag_square);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameAudioSpectrogram, AudioSpectrogram);
|
||||
} // namespace mindspore
|
|
@ -1,59 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "c_ops/batch_norm.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
void BatchNorm::Init(bool is_training, float epsilon, const Format &format) {
|
||||
set_is_training(is_training);
|
||||
set_epsilon(epsilon);
|
||||
set_format(format);
|
||||
}
|
||||
|
||||
void BatchNorm::set_is_training(bool is_training) { this->AddAttr(kIsTraining, MakeValue(is_training)); }
|
||||
|
||||
void BatchNorm::set_epsilon(float epsilon) {
|
||||
CheckAndConvertUtils::CheckInRange(kEpsilon, epsilon, kIncludeBoth, {0.0, 1.0}, this->name());
|
||||
this->AddAttr(kEpsilon, MakeValue(epsilon));
|
||||
}
|
||||
|
||||
void BatchNorm::set_format(const Format &format) {
|
||||
int64_t f = format;
|
||||
this->AddAttr(kFormat, MakeValue(f));
|
||||
}
|
||||
|
||||
bool BatchNorm::get_is_trainging() {
|
||||
auto value_ptr = GetAttr(kIsTraining);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
float BatchNorm::get_epsilon() {
|
||||
auto value_ptr = GetAttr(kEpsilon);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
|
||||
Format BatchNorm::get_format() const {
|
||||
auto value_ptr = GetAttr(kFormat);
|
||||
return Format(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameBatchNorm, BatchNorm);
|
||||
} // namespace mindspore
|
|
@ -1,21 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/batch_norm_fold.h"
|
||||
|
||||
namespace mindspore {
|
||||
REGISTER_PRIMITIVE_C(kNameBatchNormFold, BatchNormFold);
|
||||
} // namespace mindspore
|
|
@ -1,31 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/binary_cross_entropy_grad.h"
|
||||
|
||||
namespace mindspore {
|
||||
void BinaryCrossEntropyGrad::Init(const std::string &reduction) { set_reduction(reduction); }
|
||||
|
||||
void BinaryCrossEntropyGrad::set_reduction(const std::string &reduction) {
|
||||
CheckAndConvertUtils::CheckString(kReduction, reduction, {"none", "mean", "sum"}, name());
|
||||
this->AddAttr(kReduction, MakeValue(reduction));
|
||||
}
|
||||
std::string BinaryCrossEntropyGrad::get_reduction() const {
|
||||
auto value_ptr = GetAttr(kReduction);
|
||||
return GetValue<std::string>(value_ptr);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameBinaryCrossEntropyGrad, BinaryCrossEntropyGrad);
|
||||
} // namespace mindspore
|
|
@ -1,42 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/broadcast.h"
|
||||
#include "c_ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
void Broadcast::Init(int64_t root_rank, const std::string &group) {
|
||||
this->set_root_rank(root_rank);
|
||||
this->set_group(group);
|
||||
}
|
||||
void Broadcast::set_root_rank(int64_t root_rank) { this->AddAttr(kKeepProb, MakeValue(root_rank)); }
|
||||
|
||||
void Broadcast::set_group(const std::string &group) {
|
||||
CheckAndConvertUtils::CheckString(kGroup, group, {"hccl_world_group", "hccl_world_group"}, this->name());
|
||||
this->AddAttr(kGroup, MakeValue(group));
|
||||
}
|
||||
int64_t Broadcast::get_root_rank() {
|
||||
auto value_ptr = this->GetAttr(kRootRank);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
|
||||
std::string Broadcast::get_group() const {
|
||||
auto value_ptr = this->GetAttr(kGroup);
|
||||
return GetValue<std::string>(value_ptr);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameBroadcast, Broadcast);
|
||||
} // namespace mindspore
|
|
@ -1,21 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/ceil.h"
|
||||
|
||||
namespace mindspore {
|
||||
REGISTER_PRIMITIVE_C(kNameCeil, Ceil);
|
||||
}
|
|
@ -1,21 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/cos.h"
|
||||
|
||||
namespace mindspore {
|
||||
REGISTER_PRIMITIVE_C(kNameCos, Cos);
|
||||
}
|
|
@ -1,44 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/custom_predict.h"
|
||||
#include "c_ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
void CustomPredict::Init(int64_t outputNum, float weight_threshold) {
|
||||
this->set_outputNum(outputNum);
|
||||
this->set_weight_threshold(weight_threshold);
|
||||
}
|
||||
|
||||
void CustomPredict::set_outputNum(int64_t outputNum) { this->AddAttr(kOutputNum, MakeValue(outputNum)); }
|
||||
|
||||
int64_t CustomPredict::get_outputNum() const {
|
||||
auto value_ptr = this->GetAttr(kOutputNum);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
void CustomPredict::set_weight_threshold(float weight_threshold) {
|
||||
this->AddAttr(kWeightThreshold, MakeValue(weight_threshold));
|
||||
}
|
||||
|
||||
float CustomPredict::get_weight_threshold() const {
|
||||
auto value_ptr = this->GetAttr(kWeightThreshold);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameCustomPredict, CustomPredict);
|
||||
} // namespace mindspore
|
|
@ -1,21 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/div.h"
|
||||
|
||||
namespace mindspore {
|
||||
REGISTER_PRIMITIVE_C(kNameDiv, Div);
|
||||
} // namespace mindspore
|
|
@ -1,20 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "c_ops/equal.h"
|
||||
|
||||
namespace mindspore {
|
||||
REGISTER_PRIMITIVE_C(kNameEqual, Equal);
|
||||
} // namespace mindspore
|
|
@ -1,20 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "c_ops/exp.h"
|
||||
|
||||
namespace mindspore {
|
||||
REGISTER_PRIMITIVE_C(kNameExp, Exp);
|
||||
} // namespace mindspore
|
|
@ -1,44 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/fake_quant_with_min_max_vars.h"
|
||||
#include "c_ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
void FakeQuantWithMinMaxVars::Init(const bool &narrow_range, int64_t num_bits) {
|
||||
this->set_narrow_range(narrow_range);
|
||||
this->set_num_bits(num_bits);
|
||||
}
|
||||
|
||||
void FakeQuantWithMinMaxVars::set_narrow_range(const bool &narrow_range) {
|
||||
this->AddAttr(kNarrowRange, MakeValue(narrow_range));
|
||||
}
|
||||
|
||||
bool FakeQuantWithMinMaxVars::get_narrow_range() const {
|
||||
auto value_ptr = this->GetAttr(kNarrowRange);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
void FakeQuantWithMinMaxVars::set_num_bits(int64_t num_bits) { this->AddAttr(kNumBits, MakeValue(num_bits)); }
|
||||
|
||||
int64_t FakeQuantWithMinMaxVars::get_num_bits() const {
|
||||
auto value_ptr = this->GetAttr(kNumBits);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameFakeQuantWithMinMaxVars, FakeQuantWithMinMaxVars);
|
||||
} // namespace mindspore
|
|
@ -1,22 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/fft_imag.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
REGISTER_PRIMITIVE_C(kNameFftImag, FftImag);
|
||||
}
|
|
@ -1,21 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/flatten_grad.h"
|
||||
|
||||
namespace mindspore {
|
||||
REGISTER_PRIMITIVE_C(kNameFlattenGrad, FlattenGrad);
|
||||
}
|
|
@ -1,22 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/hashtable_lookup.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
REGISTER_PRIMITIVE_C(kNameHashtableLookup, HashtableLookup);
|
||||
} // namespace mindspore
|
|
@ -1,21 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/less.h"
|
||||
|
||||
namespace mindspore {
|
||||
REGISTER_PRIMITIVE_C(kNameLess, Less);
|
||||
}
|
|
@ -1,20 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "c_ops/less_equal.h"
|
||||
|
||||
namespace mindspore {
|
||||
REGISTER_PRIMITIVE_C(kNameLessEqual, LessEqual);
|
||||
} // namespace mindspore
|
|
@ -1,65 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/local_response_normalization.h"
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "c_ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
void LocalResponseNormalization::set_depth_radius(const int64_t &depth_radius) {
|
||||
this->AddAttr(kDepthRadius, MakeValue(depth_radius));
|
||||
}
|
||||
|
||||
int64_t LocalResponseNormalization::get_depth_radius() const {
|
||||
auto value_ptr = GetAttr(kDepthRadius);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
void LocalResponseNormalization::set_bias(const float &bias) { this->AddAttr(kBias, MakeValue(bias)); }
|
||||
|
||||
float LocalResponseNormalization::get_bias() const {
|
||||
auto value_ptr = GetAttr(kBias);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
|
||||
void LocalResponseNormalization::set_alpha(const float &alpha) { this->AddAttr(kAlpha, MakeValue(alpha)); }
|
||||
|
||||
float LocalResponseNormalization::get_alpha() const {
|
||||
auto value_ptr = GetAttr(kAlpha);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
|
||||
void LocalResponseNormalization::set_beta(const float &beta) { this->AddAttr(kBeta, MakeValue(beta)); }
|
||||
|
||||
float LocalResponseNormalization::get_beta() const {
|
||||
auto value_ptr = GetAttr(kBeta);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
void LocalResponseNormalization::Init(const int64_t &depth_radius, const float &bias, const float &alpha,
|
||||
const float &beta) {
|
||||
this->set_depth_radius(depth_radius);
|
||||
this->set_bias(bias);
|
||||
this->set_alpha(alpha);
|
||||
this->set_beta(beta);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameLocalResponseNormalization, LocalResponseNormalization);
|
||||
} // namespace mindspore
|
|
@ -1,21 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/log.h"
|
||||
|
||||
namespace mindspore {
|
||||
REGISTER_PRIMITIVE_C(kNameLog, Log);
|
||||
}
|
|
@ -1,20 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "c_ops/logical_not.h"
|
||||
|
||||
namespace mindspore {
|
||||
REGISTER_PRIMITIVE_C(kNameLogicalNot, LogicalNot);
|
||||
} // namespace mindspore
|
|
@ -1,20 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "c_ops/logical_or.h"
|
||||
|
||||
namespace mindspore {
|
||||
REGISTER_PRIMITIVE_C(kNameLogicalOr, LogicalOr);
|
||||
} // namespace mindspore
|
|
@ -1,82 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/lstm.h"
|
||||
|
||||
namespace mindspore {
|
||||
void LSTM::set_input_size(const int64_t &input_size) {
|
||||
CheckAndConvertUtils::CheckInteger(kInput_size, input_size, kGreaterThan, 0, this->name());
|
||||
AddAttr(kInput_size, MakeValue(input_size));
|
||||
}
|
||||
int64_t LSTM::get_input_size() const {
|
||||
auto value_ptr = this->GetAttr(kInput_size);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
void LSTM::set_hidden_size(const int64_t &hidden_size) {
|
||||
CheckAndConvertUtils::CheckInteger(kHidden_size, hidden_size, kGreaterThan, 0, this->name());
|
||||
AddAttr(kHidden_size, MakeValue(hidden_size));
|
||||
}
|
||||
int64_t LSTM::get_hidden_size() const {
|
||||
auto value_ptr = this->GetAttr(kHidden_size);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
void LSTM::set_num_layers(const int64_t &num_layers) {
|
||||
CheckAndConvertUtils::CheckInteger(kNum_layers, num_layers, kGreaterThan, 0, this->name());
|
||||
AddAttr(kNum_layers, MakeValue(kNum_layers));
|
||||
}
|
||||
int64_t LSTM::get_num_layers() const {
|
||||
auto value_ptr = this->GetAttr(kNum_layers);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
void LSTM::set_has_bias(const bool &has_bias) { AddAttr(kHasBias, MakeValue(has_bias)); }
|
||||
bool LSTM::get_has_bias() const {
|
||||
auto value_ptr = this->GetAttr(kHasBias);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
void LSTM::set_dropout(const float &dropout) {
|
||||
CheckAndConvertUtils::CheckInRange(kDropout, dropout, kIncludeBoth, {0, 1}, this->name());
|
||||
AddAttr(kDropout, MakeValue(dropout));
|
||||
}
|
||||
float LSTM::get_dropout() const {
|
||||
auto value_ptr = this->GetAttr(kDropout);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
void LSTM::set_bidirectional(const bool &bidirectional) { AddAttr(kBidirectional, MakeValue(bidirectional)); }
|
||||
bool LSTM::get_bidirectional() const {
|
||||
auto value_ptr = this->GetAttr(kBidirectional);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
void LSTM::set_num_directions(const int64_t &num_directions) { AddAttr(kNumDirections, MakeValue(num_directions)); }
|
||||
int64_t LSTM::get_num_directions() const {
|
||||
auto value_ptr = this->GetAttr(kNumDirections);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
void LSTM::Init(const int64_t &input_size, const int64_t &hidden_size, const int64_t &num_layers, const bool &has_bias,
|
||||
const float &dropout, const bool &bidirectional) {
|
||||
this->set_input_size(input_size);
|
||||
this->set_hidden_size(hidden_size);
|
||||
this->set_num_layers(num_layers);
|
||||
this->set_has_bias(has_bias);
|
||||
this->set_dropout(dropout);
|
||||
this->set_bidirectional(bidirectional);
|
||||
if (bidirectional) {
|
||||
this->set_num_directions(2);
|
||||
} else {
|
||||
this->set_num_directions(1);
|
||||
}
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameLSTM, LSTM);
|
||||
} // namespace mindspore
|
|
@ -25,7 +25,7 @@
|
|||
#include <utility>
|
||||
#include "ir/tensor.h"
|
||||
#include "ir/param_info.h"
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/shape_utils.h"
|
||||
|
@ -676,7 +676,7 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
|
|||
const std::string &node_type = node_proto.op_type();
|
||||
|
||||
std::shared_ptr<Primitive> prim;
|
||||
auto op_primc_fns = OpPrimCRegister::GetInstance().GetPrimCMap();
|
||||
auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
|
||||
if (op_primc_fns.find(node_type) != op_primc_fns.end()) {
|
||||
prim = op_primc_fns[node_type]();
|
||||
} else {
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/abs.h"
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto abs_prim = primitive->cast<PrimAbsPtr>();
|
||||
MS_EXCEPTION_IF_NULL(abs_prim);
|
||||
auto prim_name = abs_prim->name();
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name);
|
||||
return std::make_shared<abstract::Shape>(in_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("input_x", input_args[0]->BuildType());
|
||||
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
return TypeIdToType(infer_type);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr AbsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Abs, prim::kPrimAbs, AbsInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameAbs, Abs);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -14,13 +14,17 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_ABS_H_
|
||||
#define MINDSPORE_CORE_C_OPS_ABS_H_
|
||||
#include "c_ops/primitive_c.h"
|
||||
#ifndef MINDSPORE_CORE_OPS_ABS_H_
|
||||
#define MINDSPORE_CORE_OPS_ABS_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAbs = "Abs";
|
||||
class Abs : public PrimitiveC {
|
||||
public:
|
||||
|
@ -29,6 +33,10 @@ class Abs : public PrimitiveC {
|
|||
MS_DECLARE_PARENT(Abs, PrimitiveC);
|
||||
void Init() {}
|
||||
};
|
||||
AbstractBasePtr AbsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimAbsPtr = std::shared_ptr<Abs>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_C_OPS_ABS_H_
|
||||
#endif // MINDSPORE_CORE_OPS_ABS_H_
|
|
@ -0,0 +1,88 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/adam.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::AbstractBasePtr AdamInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto Adam_prim = primitive->cast<PrimAdamPtr>();
|
||||
MS_EXCEPTION_IF_NULL(Adam_prim);
|
||||
auto prim_name = Adam_prim->name();
|
||||
|
||||
// infer shape
|
||||
auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShape("var_shape", input_args[0]->GetShapeTrack(), prim_name);
|
||||
auto m_shape = CheckAndConvertUtils::ConvertShapePtrToShape("m_shape", input_args[1]->GetShapeTrack(), prim_name);
|
||||
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[2]->GetShapeTrack(), prim_name);
|
||||
auto grad_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("grad_shape", input_args[9]->GetShapeTrack(), prim_name);
|
||||
CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "m_shape", m_shape, prim_name);
|
||||
CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "v_shape", v_shape, prim_name);
|
||||
CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "grad_shape", grad_shape, prim_name);
|
||||
|
||||
// infer type
|
||||
auto var_type = input_args[0]->BuildType();
|
||||
auto m_type = input_args[1]->BuildType();
|
||||
auto v_type = input_args[2]->BuildType();
|
||||
auto grad_type = input_args[9]->BuildType();
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("var_type", var_type, common_valid_types, prim_name);
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("m_type", m_type, common_valid_types, prim_name);
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_type, common_valid_types, prim_name);
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("grad_type", grad_type, common_valid_types, prim_name);
|
||||
|
||||
auto infer_var_type = var_type->cast<TensorTypePtr>()->element();
|
||||
auto infer_m_type = m_type->cast<TensorTypePtr>()->element();
|
||||
auto infer_v_type = v_type->cast<TensorTypePtr>()->element();
|
||||
// auto infer_grad_type = grad_type->cast<TensorTypePtr>()->element();
|
||||
auto output0 = std::make_shared<abstract::AbstractTensor>(infer_var_type, var_shape);
|
||||
auto output1 = std::make_shared<abstract::AbstractTensor>(infer_m_type, m_shape);
|
||||
auto output2 = std::make_shared<abstract::AbstractTensor>(infer_v_type, v_shape);
|
||||
AbstractBasePtrList output = {output0, output1, output2};
|
||||
return std::make_shared<abstract::AbstractTuple>(output);
|
||||
}
|
||||
} // namespace
|
||||
void Adam::Init(const bool use_locking, const bool use_nesterov) {
|
||||
this->set_use_locking(use_locking);
|
||||
this->set_use_nesterov(use_nesterov);
|
||||
}
|
||||
|
||||
void Adam::set_use_locking(const bool use_locking) { this->AddAttr(kUseLocking, MakeValue(use_locking)); }
|
||||
|
||||
void Adam::set_use_nesterov(const bool use_nesterov) { this->AddAttr(kUseNesterov, MakeValue(use_nesterov)); }
|
||||
|
||||
bool Adam::get_use_locking() const {
|
||||
auto value_ptr = GetAttr(kUseLocking);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
bool Adam::get_use_nesterov() const {
|
||||
auto value_ptr = GetAttr(kUseNesterov);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
AbstractBasePtr AdamInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(AdamInfer(primitive, input_args));
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Adam, prim::kPrimAdam, AdamInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameAdam, Adam);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -14,29 +14,34 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_ADAM_H_
|
||||
#define MINDSPORE_CORE_C_OPS_ADAM_H_
|
||||
#ifndef MINDSPORE_CORE_OPS_ADAM_H_
|
||||
#define MINDSPORE_CORE_OPS_ADAM_H_
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAdam = "Adam";
|
||||
class Adam : public PrimitiveC {
|
||||
public:
|
||||
Adam() : PrimitiveC(kNameAdam) {}
|
||||
~Adam() = default;
|
||||
MS_DECLARE_PARENT(Adam, PrimitiveC);
|
||||
void Init(const bool &use_locking = false, const bool &use_nesteroy = false);
|
||||
void set_use_locking(const bool &use_locking);
|
||||
void set_use_nesteroy(const bool &use_nesteroy);
|
||||
void Init(const bool use_locking = false, const bool use_nesterov = false);
|
||||
void set_use_locking(const bool use_locking);
|
||||
void set_use_nesterov(const bool use_nesterov);
|
||||
bool get_use_locking() const;
|
||||
bool get_use_nesteroy() const;
|
||||
bool get_use_nesterov() const;
|
||||
};
|
||||
AbstractBasePtr AdamInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimAdamPtr = std::shared_ptr<Adam>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_C_OPS_ADAM_H_
|
||||
#endif // MINDSPORE_CORE_OPS_ADAM_H_
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/add.h"
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto add_prim = primitive->cast<PrimAddPtr>();
|
||||
MS_EXCEPTION_IF_NULL(add_prim);
|
||||
auto prim_name = add_prim->name();
|
||||
return BroadCastInferShape(prim_name, input_args);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("y", input_args[1]->BuildType());
|
||||
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
return TypeIdToType(infer_type);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Add, prim::kPrimAdd, AddInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameAdd, Add);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -14,21 +14,23 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_ADD_H_
|
||||
#define MINDSPORE_CORE_C_OPS_ADD_H_
|
||||
#ifndef MINDSPORE_CORE_OPS_ADD_H_
|
||||
#define MINDSPORE_CORE_OPS_ADD_H_
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAdd = "Add";
|
||||
class Add : public PrimitiveC {
|
||||
public:
|
||||
Add() : PrimitiveC(kNameAdd) { InitIOName({"x", "y"}, {"output"}); }
|
||||
explicit Add(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x", "y"}, {"output"}); }
|
||||
~Add() = default;
|
||||
MS_DECLARE_PARENT(Add, PrimitiveC);
|
||||
void Init() {}
|
||||
|
@ -37,6 +39,7 @@ class Add : public PrimitiveC {
|
|||
AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimAddPtr = std::shared_ptr<Add>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_C_OPS_ADD_H_
|
||||
#endif // MINDSPORE_CORE_OPS_ADD_H_
|
|
@ -14,8 +14,10 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/add_fold.h"
|
||||
#include "ops/add_fold.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
REGISTER_PRIMITIVE_C(kNameAddFold, AddFold);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -14,17 +14,18 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_ADDFOLD_H_
|
||||
#define MINDSPORE_CORE_C_OPS_ADDFOLD_H_
|
||||
#ifndef MINDSPORE_CORE_OPS_ADD_FOLD_H_
|
||||
#define MINDSPORE_CORE_OPS_ADD_FOLD_H_
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAddFold = "AddFold";
|
||||
class AddFold : public PrimitiveC {
|
||||
public:
|
||||
|
@ -33,6 +34,7 @@ class AddFold : public PrimitiveC {
|
|||
MS_DECLARE_PARENT(AddFold, PrimitiveC);
|
||||
void Init() {}
|
||||
};
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_C_OPS_ADDFOLD_H_
|
||||
#endif // MINDSPORE_CORE_OPS_ADD_FOLD_H_
|
|
@ -0,0 +1,108 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/adder.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
void Adder::Init(const int64_t in_channel, const int64_t out_channel, const std::vector<int64_t> &kernel_size,
|
||||
const PadMode &pad_mode, const std::vector<int64_t> &stride, const std::vector<int64_t> &pad_list,
|
||||
const std::vector<int64_t> &dilation, const int64_t group, const Format &format) {
|
||||
set_in_channel(in_channel);
|
||||
set_out_channel(out_channel);
|
||||
set_kernel_size(kernel_size);
|
||||
set_pad_mode(pad_mode);
|
||||
set_stride(stride);
|
||||
set_pad_list(pad_list);
|
||||
set_dilation(dilation);
|
||||
set_group(group);
|
||||
set_format(format);
|
||||
}
|
||||
|
||||
void Adder::set_in_channel(const int64_t in_channel) { this->AddAttr(kInChannel, MakeValue(in_channel)); }
|
||||
|
||||
int64_t Adder::get_in_channel() const {
|
||||
auto value_ptr = GetAttr(kInChannel);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
void Adder::set_out_channel(const int64_t out_channel) { this->AddAttr(kOutChannel, MakeValue(out_channel)); }
|
||||
|
||||
int64_t Adder::get_out_channel() const {
|
||||
auto value_ptr = GetAttr(kOutChannel);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
void Adder::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
||||
this->AddAttr(kKernelSize, MakeValue(kernel_size));
|
||||
}
|
||||
|
||||
std::vector<int64_t> Adder::get_kernel_size() const {
|
||||
auto value_ptr = GetAttr(kKernelSize);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
void Adder::set_pad_mode(const PadMode &pad_mode) {
|
||||
int64_t swi = pad_mode;
|
||||
this->AddAttr(kPadMode, MakeValue(swi));
|
||||
}
|
||||
|
||||
PadMode Adder::get_pad_mode() const {
|
||||
auto value_ptr = GetAttr(kPadMode);
|
||||
return PadMode(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
void Adder::set_stride(const std::vector<int64_t> &stride) { this->AddAttr(kStride, MakeValue(stride)); }
|
||||
|
||||
std::vector<int64_t> Adder::get_stride() const {
|
||||
auto value_ptr = GetAttr(kStride);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
void Adder::set_pad_list(const std::vector<int64_t> &pad_list) { this->AddAttr(kPadList, MakeValue(pad_list)); }
|
||||
|
||||
std::vector<int64_t> Adder::get_pad_list() const {
|
||||
auto value_ptr = GetAttr(kPadList);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
void Adder::set_dilation(const std::vector<int64_t> &dilation) { this->AddAttr(kDilation, MakeValue(dilation)); }
|
||||
|
||||
std::vector<int64_t> Adder::get_dilation() const {
|
||||
auto value_ptr = GetAttr(kDilation);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
void Adder::set_group(const int64_t group) { this->AddAttr(kGroup, MakeValue(group)); }
|
||||
|
||||
int64_t Adder::get_group() const {
|
||||
auto value_ptr = GetAttr(kGroup);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
void Adder::set_format(const Format &format) {
|
||||
int64_t swi = format;
|
||||
this->AddAttr(kFormat, MakeValue(swi));
|
||||
}
|
||||
|
||||
Format Adder::get_format() const {
|
||||
auto value_ptr = GetAttr(kFormat);
|
||||
return Format(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameAdder, Adder);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,62 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_ADDER_H_
|
||||
#define MINDSPORE_CORE_OPS_ADDER_H_
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAdder = "Adder";
|
||||
class Adder : public PrimitiveC {
|
||||
public:
|
||||
explicit Adder(const std::string &k_name = kNameAdder) : PrimitiveC(k_name) {}
|
||||
~Adder() = default;
|
||||
MS_DECLARE_PARENT(Adder, PrimitiveC);
|
||||
void Init(const int64_t in_channel, const int64_t out_channel, const std::vector<int64_t> &kernel_size,
|
||||
const PadMode &pad_mode, const std::vector<int64_t> &stride, const std::vector<int64_t> &pad_list,
|
||||
const std::vector<int64_t> &dilation, const int64_t group, const Format &format);
|
||||
void set_in_channel(const int64_t in_channel);
|
||||
void set_out_channel(const int64_t out_channel);
|
||||
void set_kernel_size(const std::vector<int64_t> &kernel_size);
|
||||
void set_pad_mode(const PadMode &pad_mode);
|
||||
void set_stride(const std::vector<int64_t> &stride);
|
||||
void set_pad_list(const std::vector<int64_t> &pad_list);
|
||||
void set_dilation(const std::vector<int64_t> &dilation);
|
||||
void set_group(const int64_t group);
|
||||
void set_format(const Format &format);
|
||||
|
||||
int64_t get_in_channel() const;
|
||||
int64_t get_out_channel() const;
|
||||
std::vector<int64_t> get_kernel_size() const;
|
||||
PadMode get_pad_mode() const;
|
||||
std::vector<int64_t> get_stride() const;
|
||||
std::vector<int64_t> get_pad_list() const;
|
||||
std::vector<int64_t> get_dilation() const;
|
||||
int64_t get_group() const;
|
||||
Format get_format() const;
|
||||
};
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_ADDER_H_
|
|
@ -0,0 +1,70 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include "ops/addn.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto input_tuple = input_args[0]->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_tuple);
|
||||
auto elements = input_tuple->elements();
|
||||
CheckAndConvertUtils::CheckInteger("concat element num", elements.size(), kGreaterEqual, 1, prim_name);
|
||||
auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(element0);
|
||||
auto element0_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("element0 shape", element0->BuildShape(), prim_name);
|
||||
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("element0", element0->BuildType());
|
||||
for (size_t i = 1; i < elements.size(); ++i) {
|
||||
std::string elementi = "element" + std::to_string(i);
|
||||
auto elementi_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape(elementi + " shape", elements[i]->BuildShape(), prim_name);
|
||||
CheckAndConvertUtils::CheckInteger(elementi + " shape rank", elementi_shape.size(), kEqual, element0_shape.size(),
|
||||
prim_name);
|
||||
for (size_t j = 0; j < element0_shape.size(); ++j) {
|
||||
if (elementi_shape[j] != element0_shape[j]) {
|
||||
MS_LOG(EXCEPTION) << "element " << i << " shape in input can not concat with first element.";
|
||||
}
|
||||
}
|
||||
types.emplace(elementi, elements[i]->BuildType());
|
||||
}
|
||||
std::set<TypeId> valid_types = common_valid_types;
|
||||
valid_types.insert(kNumberTypeBool);
|
||||
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
|
||||
|
||||
return std::make_shared<abstract::AbstractTensor>(TypeIdToType(infer_type),
|
||||
std::make_shared<abstract::Shape>(element0_shape));
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(AddN, prim::kPrimAddN, AddNInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameAddN, AddN);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -14,13 +14,16 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_ADDN_H_
|
||||
#define MINDSPORE_CORE_C_OPS_ADDN_H_
|
||||
#include "c_ops/primitive_c.h"
|
||||
#ifndef MINDSPORE_CORE_OPS_ADDN_H_
|
||||
#define MINDSPORE_CORE_OPS_ADDN_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAddN = "AddN";
|
||||
class AddN : public PrimitiveC {
|
||||
public:
|
||||
|
@ -29,6 +32,10 @@ class AddN : public PrimitiveC {
|
|||
MS_DECLARE_PARENT(AddN, PrimitiveC);
|
||||
void Init() {}
|
||||
};
|
||||
AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimAddNPtr = std::shared_ptr<AddN>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_C_OPS_ADDN_H_
|
||||
#endif // MINDSPORE_CORE_OPS_ADDN_H_
|
|
@ -14,19 +14,20 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/concat.h"
|
||||
#include "c_ops/op_utils.h"
|
||||
#include "ops/all.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
namespace mindspore {
|
||||
|
||||
void Concat::Init(int64_t axis) { this->set_axis(axis); }
|
||||
int64_t Concat::get_axis() const {
|
||||
auto value_ptr = this->GetAttr(kAxis);
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
void All::Init(const int64_t keep_dims) { this->set_keep_dims(keep_dims); }
|
||||
|
||||
void All::set_keep_dims(const int64_t keep_dims) { this->AddAttr(kKeepDims, MakeValue(keep_dims)); }
|
||||
|
||||
int64_t All::get_keep_dims() const {
|
||||
auto value_ptr = GetAttr(kKeepDims);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
void Concat::set_axis(int64_t axis) {
|
||||
this->AddAttr(kAxis, MakeValue(CheckAndConvertUtils::CheckInteger(kAxis, axis, kGreaterEqual, 0, this->name())));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameConcat, Concat);
|
||||
REGISTER_PRIMITIVE_C(kNameAll, All);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_ALL_H_
|
||||
#define MINDSPORE_CORE_OPS_ALL_H_
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAll = "All";
|
||||
class All : public PrimitiveC {
|
||||
public:
|
||||
All() : PrimitiveC(kNameAll) {}
|
||||
~All() = default;
|
||||
MS_DECLARE_PARENT(All, PrimitiveC);
|
||||
void Init(const int64_t keep_dims);
|
||||
void set_keep_dims(const int64_t keep_dims);
|
||||
int64_t get_keep_dims() const;
|
||||
};
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_ALL_H_
|
|
@ -0,0 +1,89 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include "ops/apply_momentum.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
void ApplyMomentum::Init(const bool use_nesterov, const bool use_locking, const float gradient_scale) {
|
||||
this->set_use_nesterov(use_nesterov);
|
||||
this->set_use_locking(use_locking);
|
||||
this->set_gradient_scale(gradient_scale);
|
||||
}
|
||||
|
||||
void ApplyMomentum::set_use_nesterov(const bool use_nesterov) { this->AddAttr(kUseNesterov, MakeValue(use_nesterov)); }
|
||||
|
||||
void ApplyMomentum::set_use_locking(const bool use_locking) { this->AddAttr(kUseLocking, MakeValue(use_locking)); }
|
||||
|
||||
void ApplyMomentum::set_gradient_scale(const float gradient_scale) {
|
||||
this->AddAttr(kGradientScale, MakeValue(gradient_scale));
|
||||
}
|
||||
|
||||
bool ApplyMomentum::get_use_nesterov() const {
|
||||
auto value_ptr = GetAttr(kUseNesterov);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
bool ApplyMomentum::get_use_locking() const {
|
||||
auto value_ptr = GetAttr(kUseLocking);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
float ApplyMomentum::get_gradient_scale() const {
|
||||
auto value_ptr = GetAttr(kGradientScale);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto momentum_prim = primitive->cast<PrimApplyMomentumPtr>();
|
||||
MS_EXCEPTION_IF_NULL(momentum_prim);
|
||||
auto prim_name = momentum_prim->name();
|
||||
CheckAndConvertUtils::CheckInteger("apply_momentum_infer", input_args.size(), kEqual, 5, prim_name);
|
||||
|
||||
// Infer shape
|
||||
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[0]->BuildShape(), prim_name);
|
||||
|
||||
// Infer type
|
||||
auto v_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
auto a_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
auto l_type = input_args[2]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
auto g_type = input_args[3]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
auto m_type = input_args[4]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_type, valid_types, prim_name);
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("a_type", a_type, valid_types, prim_name);
|
||||
const std::set<TypePtr> valid_types_ptr = {TypeIdToType(kNumberTypeFloat16), TypeIdToType(kNumberTypeFloat32),
|
||||
TypeIdToType(kNumberTypeFloat64)};
|
||||
std::map<std::string, TypePtr> args;
|
||||
args.insert({"l_type", l_type});
|
||||
args.insert({"g_type", g_type});
|
||||
args.insert({"m_type", m_type});
|
||||
CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, valid_types_ptr, prim_name);
|
||||
|
||||
return std::make_shared<abstract::AbstractTensor>(g_type, v_shape);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ApplyMomentum, prim::kPrimApplyMomentum, ApplyMomentumInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameApplyMomentum, ApplyMomentum);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -14,13 +14,17 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_APPLYMOMENTUM_H_
|
||||
#define MINDSPORE_CORE_C_OPS_APPLYMOMENTUM_H_
|
||||
#include "c_ops/primitive_c.h"
|
||||
#ifndef MINDSPORE_CORE_OPS_APPLY_MOMENTUM_H_
|
||||
#define MINDSPORE_CORE_OPS_APPLY_MOMENTUM_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameApplyMomentum = "ApplyMomentum";
|
||||
class ApplyMomentum : public PrimitiveC {
|
||||
public:
|
||||
|
@ -29,14 +33,18 @@ class ApplyMomentum : public PrimitiveC {
|
|||
}
|
||||
~ApplyMomentum() = default;
|
||||
MS_DECLARE_PARENT(ApplyMomentum, PrimitiveC);
|
||||
void Init(bool use_nesterov, bool use_locking, float gradient_scale);
|
||||
void set_use_nesterov(bool use_nesterov);
|
||||
void set_use_locking(bool use_locking);
|
||||
void set_gradient_scale(float gradient_scale);
|
||||
void Init(const bool use_nesterov = false, const bool use_locking = false, const float gradient_scale = 1.0);
|
||||
void set_use_nesterov(const bool use_nesterov);
|
||||
void set_use_locking(const bool use_locking);
|
||||
void set_gradient_scale(const float gradient_scale);
|
||||
bool get_use_nesterov() const;
|
||||
bool get_use_locking() const;
|
||||
float get_gradient_scale();
|
||||
float get_gradient_scale() const;
|
||||
};
|
||||
AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimApplyMomentumPtr = std::shared_ptr<ApplyMomentum>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_C_OPS_APPLYMOMENTUM_H_
|
||||
#endif // MINDSPORE_CORE_OPS_APPLY_MOMENTUM_H_
|
|
@ -0,0 +1,73 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/arg_max.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto prim = primitive->cast<PrimArgMaxPtr>();
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto axis = prim->get_axis();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto x_rank = SizeToLong(x_shape.size());
|
||||
CheckAndConvertUtils::CheckInRange<int64_t>("argmax axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name);
|
||||
axis = axis < 0 ? axis + x_rank : axis;
|
||||
std::vector<int64_t> out_shape;
|
||||
for (size_t i = 0; i < x_shape.size(); ++i) {
|
||||
if (SizeToLong(i) != axis) {
|
||||
out_shape.emplace_back(x_shape[i]);
|
||||
}
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim->name());
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
return kInt32;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void ArgMax::Init(const int64_t axis, const TypeId output_type) {
|
||||
set_axis(axis);
|
||||
set_output_type(output_type);
|
||||
}
|
||||
|
||||
void ArgMax::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis)); }
|
||||
void ArgMax::set_output_type(const TypeId output_type) { this->AddAttr(kOutputType, TypeIdToType(output_type)); }
|
||||
|
||||
int64_t ArgMax::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); }
|
||||
TypeId ArgMax::get_output_type() const {
|
||||
auto type_ptr = GetAttr(kOutputType)->cast<TensorTypePtr>()->element();
|
||||
return type_ptr->type_id();
|
||||
}
|
||||
|
||||
AbstractBasePtr ArgMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ArgMax, prim::kPrimArgMax, ArgMaxInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameArgMax, ArgMax);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_ARG_MAX_H_
|
||||
#define MINDSPORE_CORE_OPS_ARG_MAX_H_
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameArgMax = "Argmax";
|
||||
class ArgMax : public PrimitiveC {
|
||||
public:
|
||||
ArgMax() : PrimitiveC(kNameArgMax) { InitIOName({"x"}, {"output"}); }
|
||||
explicit ArgMax(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x"}, {"output"}); }
|
||||
~ArgMax() = default;
|
||||
MS_DECLARE_PARENT(ArgMax, PrimitiveC);
|
||||
void Init(const int64_t axis = -1, const TypeId output_type = kNumberTypeInt32);
|
||||
void set_axis(const int64_t axis);
|
||||
void set_output_type(const TypeId output_type);
|
||||
|
||||
int64_t get_axis() const;
|
||||
TypeId get_output_type() const;
|
||||
};
|
||||
AbstractBasePtr ArgMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimArgMaxPtr = std::shared_ptr<ArgMax>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_ARG_MAX_H_
|
|
@ -0,0 +1,73 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <set>
|
||||
#include "ops/arg_min.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
void ArgMin::Init(const int64_t axis, const TypeId output_type) {
|
||||
set_axis(axis);
|
||||
set_output_type(output_type);
|
||||
}
|
||||
|
||||
void ArgMin::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis)); }
|
||||
void ArgMin::set_output_type(const TypeId output_type) { this->AddAttr(kOutputType, TypeIdToType(output_type)); }
|
||||
|
||||
int64_t ArgMin::get_axis() const {
|
||||
auto value_ptr = GetAttr(kAxis);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
TypeId ArgMin::get_output_type() const {
|
||||
auto type_ptr = GetAttr(kOutputType)->cast<TensorTypePtr>()->element();
|
||||
return type_ptr->type_id();
|
||||
}
|
||||
|
||||
AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto argmin_prim = primitive->cast<PrimArgMin>();
|
||||
MS_EXCEPTION_IF_NULL(argmin_prim);
|
||||
auto prim_name = argmin_prim->name();
|
||||
CheckAndConvertUtils::CheckInteger("arg_min_infer", input_args.size(), kEqual, 1, prim_name);
|
||||
|
||||
// Infer shape
|
||||
auto axis = argmin_prim->get_axis();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto x_rank = SizeToLong(x_shape.size());
|
||||
CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name);
|
||||
if (axis < 0) {
|
||||
axis += x_rank;
|
||||
}
|
||||
std::vector<int64_t> out_shape;
|
||||
for (int64_t i = 0; i < x_rank; i++) {
|
||||
if (i != axis) {
|
||||
out_shape.push_back(x_shape[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Infer type
|
||||
auto x_dtype = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
std::set<TypePtr> template_types = {TypeIdToType(kObjectTypeTensorType)};
|
||||
CheckAndConvertUtils::CheckSubClass("x_dtype", x_dtype, template_types, prim_name);
|
||||
|
||||
return std::make_shared<abstract::AbstractTensor>(x_dtype, std::make_shared<abstract::Shape>(out_shape));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ArgMin, prim::kPrimArgMin, ArgMinInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameArgMin, ArgMin);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -14,32 +14,37 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_ARGMIN_H_
|
||||
#define MINDSPORE_CORE_C_OPS_ARGMIN_H_
|
||||
#ifndef MINDSPORE_CORE_OPS_ARG_MIN_H_
|
||||
#define MINDSPORE_CORE_OPS_ARG_MIN_H_
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "c_ops/op_utils.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameArgMin = "ArgMin";
|
||||
class ArgMin : public PrimitiveC {
|
||||
public:
|
||||
ArgMin() : PrimitiveC(kNameArgMin) { InitIOName({"x"}, {"output"}); }
|
||||
explicit ArgMin(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x"}, {"output"}); }
|
||||
~ArgMin() = default;
|
||||
MS_DECLARE_PARENT(ArgMin, PrimitiveC);
|
||||
void Init(bool keep_dims, int64_t axis = -1);
|
||||
void set_axis(int64_t axis);
|
||||
void set_keep_dims(bool keep_dims);
|
||||
int64_t get_axis();
|
||||
bool get_keep_dims();
|
||||
void Init(const int64_t axis = -1, const TypeId output_type = kNumberTypeInt32);
|
||||
void set_axis(const int64_t axis);
|
||||
void set_output_type(const TypeId output_type);
|
||||
|
||||
int64_t get_axis() const;
|
||||
TypeId get_output_type() const;
|
||||
};
|
||||
AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimArgMin = std::shared_ptr<ArgMin>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_C_OPS_ARGMIN_H_
|
||||
#endif // MINDSPORE_CORE_OPS_ARG_MIN_H_
|
|
@ -0,0 +1,52 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/asin.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
AbstractBasePtr AsinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto asin_prim = primitive->cast<PrimAsinPtr>();
|
||||
MS_EXCEPTION_IF_NULL(asin_prim);
|
||||
auto prim_name = asin_prim->name();
|
||||
CheckAndConvertUtils::CheckInteger("Asin_infer", input_args.size(), kEqual, 1, prim_name);
|
||||
|
||||
// Infer Shape
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto infer_shape = std::make_shared<abstract::Shape>(x_shape);
|
||||
|
||||
// Infer Type
|
||||
auto dtype = input_args[0]->BuildType();
|
||||
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeInt32};
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim_name);
|
||||
auto tensor_type = dtype->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||
auto element = tensor_type->element();
|
||||
MS_EXCEPTION_IF_NULL(element);
|
||||
auto infer_type = std::make_shared<TensorType>(TypeIdToType(element->type_id()));
|
||||
|
||||
return std::make_shared<abstract::AbstractTensor>(infer_type, infer_shape->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Asin, prim::kPrimAsin, AsinInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameAsin, Asin);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_ASIN_H_
|
||||
#define MINDSPORE_CORE_OPS_ASIN_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAsin = "Asin";
|
||||
class Asin : public PrimitiveC {
|
||||
public:
|
||||
Asin() : PrimitiveC(kNameAsin) {}
|
||||
~Asin() = default;
|
||||
MS_DECLARE_PARENT(Asin, PrimitiveC);
|
||||
void Init() {}
|
||||
};
|
||||
AbstractBasePtr ASinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimAsinPtr = std::shared_ptr<Asin>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_ASIN_H_
|
|
@ -0,0 +1,78 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/assert.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
void Assert::Init(const int64_t summarize) { set_summarize(summarize); }
|
||||
|
||||
void Assert::set_summarize(const int64_t summarize) { this->AddAttr(kSummarize, MakeValue(summarize)); }
|
||||
|
||||
int64_t Assert::get_summarize() const {
|
||||
auto value_ptr = GetAttr(kSummarize);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto Assert_prim = primitive->cast<PrimAssertPtr>();
|
||||
MS_EXCEPTION_IF_NULL(Assert_prim);
|
||||
auto op_name = Assert_prim->name();
|
||||
TypePtr condition;
|
||||
if (!(input_args[0]->BuildType()->type_id() == kObjectTypeTensorType)) {
|
||||
auto condition_value = GetValue<std::vector<bool>>(input_args[0]->BuildValue());
|
||||
CheckAndConvertUtils::CheckInteger("condition's rank", condition_value.size(), kLessEqual, 1, op_name);
|
||||
if (condition_value.size() == 1) {
|
||||
CheckAndConvertUtils::CheckInteger("condition[0]", condition_value[0], kEqual, 1, op_name);
|
||||
}
|
||||
condition = TypeIdToType(kNumberTypeBool);
|
||||
} else {
|
||||
auto condition_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
|
||||
CheckAndConvertUtils::CheckInteger("condition's rank", condition_shape[0], kLessEqual, 1, op_name);
|
||||
if (condition_shape[0] == 1) {
|
||||
auto condition_value = reinterpret_cast<bool *>(input_args[0]->BuildValue()->cast<tensor::TensorPtr>()->data_c());
|
||||
MS_EXCEPTION_IF_NULL(condition_value);
|
||||
// auto condition_value = GetValue<bool>(input_args[0]->BuildValue());
|
||||
CheckAndConvertUtils::CheckInteger("condition[0]", *condition_value, kEqual, 1, op_name);
|
||||
}
|
||||
condition = input_args[0]->BuildType();
|
||||
}
|
||||
std::vector<int64_t> output_shape = {1};
|
||||
std::set<TypePtr> local_bool = {TypeIdToType(kNumberTypeBool)};
|
||||
std::map<std::string, TypePtr> args = {{"condition", condition}};
|
||||
CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, local_bool, op_name);
|
||||
auto inputs_type = input_args[1]->BuildType()->cast<TuplePtr>()->elements();
|
||||
for (auto dtype : inputs_type) {
|
||||
std::set<TypePtr> template_types = {TypeIdToType(kObjectTypeTensorType)};
|
||||
CheckAndConvertUtils::CheckSubClass("input", dtype, template_types, op_name);
|
||||
}
|
||||
return std::make_shared<abstract::AbstractTensor>(TypeIdToType(kNumberTypeInt32), output_shape);
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Assert, prim::kPrimAssert, AssertInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameAssert, Assert);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_ASSERT_H_
|
||||
#define MINDSPORE_CORE_OPS_ASSERT_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAssert = "Assert";
|
||||
class Assert : public PrimitiveC {
|
||||
public:
|
||||
Assert() : PrimitiveC(kNameAssert) {}
|
||||
~Assert() = default;
|
||||
MS_DECLARE_PARENT(Assert, PrimitiveC);
|
||||
void Init(const int64_t summarize = 3);
|
||||
void set_summarize(const int64_t summarize);
|
||||
int64_t get_summarize() const;
|
||||
};
|
||||
|
||||
AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimAssertPtr = std::shared_ptr<Assert>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_ASSERT_H_
|
|
@ -14,8 +14,18 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/assign.h"
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "ops/assign.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "ir/dtype/ref.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
REGISTER_PRIMITIVE_C(kNameAssign, Assign);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -14,13 +14,17 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_ASSIGN_H_
|
||||
#define MINDSPORE_CORE_C_OPS_ASSIGN_H_
|
||||
#include "c_ops/primitive_c.h"
|
||||
#ifndef MINDSPORE_CORE_OPS_ASSIGN_H_
|
||||
#define MINDSPORE_CORE_OPS_ASSIGN_H_
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAssign = "Assign";
|
||||
class Assign : public PrimitiveC {
|
||||
public:
|
||||
|
@ -29,6 +33,9 @@ class Assign : public PrimitiveC {
|
|||
MS_DECLARE_PARENT(Assign, PrimitiveC);
|
||||
void Init() {}
|
||||
};
|
||||
|
||||
using PrimAssignPtr = std::shared_ptr<Assign>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_C_OPS_ASSIGN_H_
|
||||
#endif // MINDSPORE_CORE_OPS_ASSIGN_H_
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include "ops/assign_add.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto assignadd_prim = primitive->cast<PrimAssignAddPtr>();
|
||||
MS_EXCEPTION_IF_NULL(assignadd_prim);
|
||||
auto prim_name = assignadd_prim->name();
|
||||
auto value_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("value_shape", input_args[1]->BuildShape(), prim_name);
|
||||
return std::make_shared<abstract::Shape>(value_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("w", input_args[1]->BuildType());
|
||||
// check_scalar_or_tensor_types_same
|
||||
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, "AssignAdd");
|
||||
return TypeIdToType(infer_type);
|
||||
}
|
||||
} // namespace
|
||||
AbstractBasePtr AssignAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(AssignAdd, prim::kPrimAssignAdd, AssignAddInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameAssignAdd, AssignAdd);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -14,13 +14,17 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_ASSIGNADD_H_
|
||||
#define MINDSPORE_CORE_C_OPS_ASSIGNADD_H_
|
||||
#include "c_ops/primitive_c.h"
|
||||
#ifndef MINDSPORE_CORE_OPS_ASSIGN_ADD_H_
|
||||
#define MINDSPORE_CORE_OPS_ASSIGN_ADD_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAssignAdd = "AssignAdd";
|
||||
class AssignAdd : public PrimitiveC {
|
||||
public:
|
||||
|
@ -29,6 +33,10 @@ class AssignAdd : public PrimitiveC {
|
|||
MS_DECLARE_PARENT(AssignAdd, PrimitiveC);
|
||||
void Init() {}
|
||||
};
|
||||
AbstractBasePtr AssignAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimAssignAddPtr = std::shared_ptr<AssignAdd>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_C_OPS_ASSIGNADD_H_
|
||||
#endif // MINDSPORE_CORE_OPS_ASSIGN_ADD_H_
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <set>
|
||||
|
||||
#include "ops/atan.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto atan_prim = primitive->cast<PrimAtanPtr>();
|
||||
MS_EXCEPTION_IF_NULL(atan_prim);
|
||||
auto prim_name = atan_prim->name();
|
||||
CheckAndConvertUtils::CheckInteger("Atan_infer", input_args.size(), kEqual, 1, prim_name);
|
||||
|
||||
// Infer Shape
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto infer_shape = std::make_shared<abstract::Shape>(x_shape);
|
||||
|
||||
// Infer Type
|
||||
auto dtype = input_args[0]->BuildType();
|
||||
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeInt32};
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim_name);
|
||||
auto tensor_type = dtype->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||
auto element = tensor_type->element();
|
||||
MS_EXCEPTION_IF_NULL(element);
|
||||
auto infer_type = std::make_shared<TensorType>(TypeIdToType(element->type_id()));
|
||||
|
||||
return std::make_shared<abstract::AbstractTensor>(infer_type, infer_shape->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Atan, prim::kPrimAtan, AtanInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameAtan, Atan);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -14,17 +14,18 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_ATAN_H_
|
||||
#define MINDSPORE_CORE_C_OPS_ATAN_H_
|
||||
#ifndef MINDSPORE_CORE_OPS_ATAN_H_
|
||||
#define MINDSPORE_CORE_OPS_ATAN_H_
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAtan = "Atan";
|
||||
class Atan : public PrimitiveC {
|
||||
public:
|
||||
|
@ -33,6 +34,10 @@ class Atan : public PrimitiveC {
|
|||
MS_DECLARE_PARENT(Atan, PrimitiveC);
|
||||
void Init() {}
|
||||
};
|
||||
AbstractBasePtr ATanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimAtanPtr = std::shared_ptr<Atan>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_C_OPS_ATAN_H_
|
||||
#endif // MINDSPORE_CORE_OPS_ATAN_H_
|
|
@ -0,0 +1,125 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/audio_spectrogram.h"
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr AudioSpectrogramInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto audio_spectrogram_prim = primitive->cast<PrimAudioSpectrogramPtr>();
|
||||
MS_EXCEPTION_IF_NULL(audio_spectrogram_prim);
|
||||
auto prim_name = audio_spectrogram_prim->name();
|
||||
auto input_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
|
||||
if (input_shape.size() != 2) {
|
||||
MS_LOG(ERROR) << "input shape is error, which need to be 2 dimensions";
|
||||
}
|
||||
if (audio_spectrogram_prim->get_window_size() < 2) {
|
||||
MS_LOG(ERROR) << "window size is too short, now is " << audio_spectrogram_prim->get_window_size();
|
||||
}
|
||||
if (audio_spectrogram_prim->get_stride() < 1) {
|
||||
MS_LOG(ERROR) << "stride must be positive, now is " << audio_spectrogram_prim->get_stride();
|
||||
}
|
||||
std::vector<int64_t> infer_shape;
|
||||
infer_shape.push_back(input_shape[1]);
|
||||
int64_t sample_sub_window = input_shape[0] - audio_spectrogram_prim->get_window_size();
|
||||
infer_shape.push_back(sample_sub_window < 0 ? 0 : 1 + sample_sub_window / audio_spectrogram_prim->get_stride());
|
||||
int64_t fft_length = audio_spectrogram_prim->GetFftLength(audio_spectrogram_prim->get_window_size());
|
||||
infer_shape.push_back(fft_length / 2 + 1);
|
||||
MS_LOG(ERROR) << infer_shape;
|
||||
return std::make_shared<abstract::Shape>(infer_shape);
|
||||
}
|
||||
|
||||
TypePtr AudioSpectrogramInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto infer_type = input_args[0]->BuildType();
|
||||
auto tensor_type = infer_type->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||
auto data_type = tensor_type->element();
|
||||
MS_EXCEPTION_IF_NULL(data_type);
|
||||
return data_type;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void AudioSpectrogram::set_window_size(const int64_t window_size) {
|
||||
this->AddAttr(kWindowSize, MakeValue(window_size));
|
||||
}
|
||||
int64_t AudioSpectrogram::get_window_size() const {
|
||||
auto value_ptr = GetAttr(kWindowSize);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
void AudioSpectrogram::set_stride(const int64_t stride) { this->AddAttr(kStride, MakeValue(stride)); }
|
||||
int64_t AudioSpectrogram::get_stride() const {
|
||||
auto value_ptr = GetAttr(kStride);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
int64_t AudioSpectrogram::Log2Ceil(int64_t length) {
|
||||
if (length == 0) {
|
||||
return -1;
|
||||
}
|
||||
int64_t floor = 0;
|
||||
for (int64_t i = 4; i >= 0; --i) {
|
||||
const int64_t shift = (int64_t)(1 << i);
|
||||
int64_t tmp = length >> shift;
|
||||
if (tmp != 0) {
|
||||
length = tmp;
|
||||
floor += shift;
|
||||
}
|
||||
}
|
||||
return length == (length & ~(length - 1)) ? floor : floor + 1;
|
||||
}
|
||||
|
||||
int64_t AudioSpectrogram::GetFftLength(int64_t length) {
|
||||
int64_t shift = Log2Ceil(length);
|
||||
return 1 << shift;
|
||||
}
|
||||
|
||||
void AudioSpectrogram::set_mag_square(const bool mag_square) { this->AddAttr(kMagSquare, MakeValue(mag_square)); }
|
||||
bool AudioSpectrogram::get_mag_square() const {
|
||||
auto value_ptr = GetAttr(kMagSquare);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
void AudioSpectrogram::Init(const int64_t window_size, const int64_t stride, const bool mag_square) {
|
||||
this->set_window_size(window_size);
|
||||
this->set_stride(stride);
|
||||
this->set_mag_square(mag_square);
|
||||
}
|
||||
|
||||
AbstractBasePtr AudioSpectrogramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(AudioSpectrogramInferType(primitive, input_args),
|
||||
AudioSpectrogramInferShape(primitive, input_args)->shape());
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(AudioSpectrogram, prim::kPrimAudioSpectrogram, AudioSpectrogramInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameAudioSpectrogram, AudioSpectrogram);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -14,31 +14,38 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_AUDIOSPECTROGRAM_H_
|
||||
#define MINDSPORE_CORE_C_OPS_AUDIOSPECTROGRAM_H_
|
||||
#ifndef MINDSPORE_CORE_OPS_AUDIO_SPECTROGRAM_H_
|
||||
#define MINDSPORE_CORE_OPS_AUDIO_SPECTROGRAM_H_
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAudioSpectrogram = "AudioSpectrogram";
|
||||
class AudioSpectrogram : public PrimitiveC {
|
||||
public:
|
||||
AudioSpectrogram() : PrimitiveC(kNameAudioSpectrogram) {}
|
||||
~AudioSpectrogram() = default;
|
||||
MS_DECLARE_PARENT(AudioSpectrogram, PrimitiveC);
|
||||
void Init(const int64_t &window_size, const int64_t &stride, const bool &mag_square);
|
||||
void set_window_size(const int64_t &window_size);
|
||||
void set_stride(const int64_t &stride);
|
||||
void set_mag_square(const bool &mag_square);
|
||||
void Init(const int64_t window_size, const int64_t stride, const bool mag_square);
|
||||
void set_window_size(const int64_t window_size);
|
||||
void set_stride(const int64_t stride);
|
||||
void set_mag_square(const bool mag_square);
|
||||
int64_t get_window_size() const;
|
||||
int64_t get_stride() const;
|
||||
bool get_mag_square() const;
|
||||
int64_t Log2Ceil(int64_t length);
|
||||
int64_t GetFftLength(int64_t length);
|
||||
};
|
||||
AbstractBasePtr AudioSpectrogramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimAudioSpectrogramPtr = std::shared_ptr<AudioSpectrogram>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_C_OPS_AUDIOSPECTROGRAM_H_
|
||||
#endif // MINDSPORE_CORE_OPS_AUDIO_SPECTROGRAM_H_
|
|
@ -14,25 +14,26 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/avg_pool.h"
|
||||
#include "ops/avg_pool.h"
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "c_ops/op_utils.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
void AvgPool::set_padding(const std::string &padding) {
|
||||
CheckAndConvertUtils::CheckString(kPadding, padding, {kValid, kSame}, this->name());
|
||||
this->AddAttr(kPadding, MakeValue(padding));
|
||||
namespace ops {
|
||||
void AvgPool::set_pad_mode(const PadMode &pad_mode) {
|
||||
int64_t swi = pad_mode;
|
||||
this->AddAttr(kPadMode, MakeValue(swi));
|
||||
}
|
||||
|
||||
std::string AvgPool::get_padding() const {
|
||||
auto value_ptr = GetAttr(kPadding);
|
||||
return GetValue<std::string>(value_ptr);
|
||||
PadMode AvgPool::get_pad_mode() const {
|
||||
auto value_ptr = GetAttr(kPadMode);
|
||||
return PadMode(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
void AvgPool::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
||||
this->AddAttr(kKernelSize, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name(),
|
||||
|
@ -44,12 +45,12 @@ std::vector<int64_t> AvgPool::get_kernel_size() const {
|
|||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
void AvgPool::set_strides(const std::vector<int64_t> &strides) {
|
||||
this->AddAttr(kStride,
|
||||
MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, strides, this->name(), false, true)));
|
||||
this->AddAttr(kStrides,
|
||||
MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name(), false, true)));
|
||||
}
|
||||
|
||||
std::vector<int64_t> AvgPool::get_strides() const {
|
||||
auto value_ptr = GetAttr(kStride);
|
||||
auto value_ptr = GetAttr(kStrides);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
|
@ -70,20 +71,19 @@ std::vector<int64_t> AvgPool::get_pad() const {
|
|||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
void AvgPool::set_round_mode(const int64_t &round_mode) {
|
||||
CheckAndConvertUtils::CheckInRange(kRoundMode, round_mode, kIncludeBoth, {0, 1}, this->name());
|
||||
this->AddAttr(kRoundMode, MakeValue(round_mode));
|
||||
void AvgPool::set_round_mode(const RoundMode &round_mode) {
|
||||
int64_t swi = round_mode;
|
||||
this->AddAttr(kRoundMode, MakeValue(swi));
|
||||
}
|
||||
|
||||
int64_t AvgPool::get_round_mode() const {
|
||||
RoundMode AvgPool::get_round_mode() const {
|
||||
auto value_ptr = GetAttr(kRoundMode);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
return RoundMode(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
void AvgPool::Init(const std::vector<int64_t> &kernel_size, const std::vector<int64_t> &stride,
|
||||
const std::string &padding, const Format &format, const std::vector<int64_t> &pad,
|
||||
const int64_t &round_mode) {
|
||||
this->set_padding(padding);
|
||||
void AvgPool::Init(const std::vector<int64_t> &kernel_size, const std::vector<int64_t> &stride, const PadMode &pad_mode,
|
||||
const Format &format, const std::vector<int64_t> &pad, const RoundMode &round_mode) {
|
||||
this->set_pad_mode(pad_mode);
|
||||
this->set_kernel_size(kernel_size);
|
||||
this->set_strides(stride);
|
||||
this->set_format(format);
|
||||
|
@ -98,9 +98,12 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
MS_EXCEPTION_IF_NULL(pool_prim);
|
||||
auto op_name = pool_prim->name();
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name);
|
||||
if (pool_prim->get_format() == NHWC) {
|
||||
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
|
||||
}
|
||||
CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name);
|
||||
auto kernel_size = pool_prim->get_kernel_size();
|
||||
auto pad_mode = pool_prim->get_padding();
|
||||
auto pad_mode = pool_prim->get_pad_mode();
|
||||
auto batch = in_shape[0];
|
||||
auto channel = in_shape[1];
|
||||
auto in_h = in_shape[2];
|
||||
|
@ -113,14 +116,17 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
auto stride_w = strides[3];
|
||||
int64_t out_h = -1;
|
||||
int64_t out_w = -1;
|
||||
if (pad_mode == "valid") {
|
||||
if (pad_mode == VALID) {
|
||||
out_h = ceil((in_h - (kernel_h - 1)) / stride_h);
|
||||
out_w = ceil((in_w - (kernel_w - 1)) / stride_w);
|
||||
} else if (pad_mode == "same") {
|
||||
} else if (pad_mode == SAME) {
|
||||
out_h = ceil(in_h / stride_h);
|
||||
out_w = ceil(in_w / stride_w);
|
||||
}
|
||||
std::vector<int64_t> out_shape = {batch, channel, out_h, out_w};
|
||||
if (pool_prim->get_format() == NHWC) {
|
||||
out_shape = {batch, out_h, out_w, channel};
|
||||
}
|
||||
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
|
||||
MS_LOG(EXCEPTION) << "Kernel size is not valid.";
|
||||
}
|
||||
|
@ -142,4 +148,5 @@ AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const Primitiv
|
|||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(AvgPool, prim::kPrimAvgPool, AvgPoolInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameAvgPool, AvgPool);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -14,45 +14,48 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_AVG_POOL_H_
|
||||
#define MINDSPORE_CORE_C_OPS_AVG_POOL_H_
|
||||
#ifndef MINDSPORE_CORE_OPS_AVG_POOL_H_
|
||||
#define MINDSPORE_CORE_OPS_AVG_POOL_H_
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAvgPool = "AvgPool";
|
||||
class AvgPool : public PrimitiveC {
|
||||
public:
|
||||
AvgPool() : PrimitiveC(kNameAvgPool) { InitIOName({"x"}, {"output"}); }
|
||||
explicit AvgPool(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x"}, {"output"}); }
|
||||
~AvgPool() = default;
|
||||
MS_DECLARE_PARENT(AvgPool, PrimitiveC);
|
||||
void Init(const std::vector<int64_t> &kernel_size = {1}, const std::vector<int64_t> &stride = {1},
|
||||
const std::string &padding = "valid", const Format &format = NCHW,
|
||||
const std::vector<int64_t> &pad = {0, 0, 0, 0}, const int64_t &round_mode = 0);
|
||||
void set_padding(const std::string &padding);
|
||||
const PadMode &pad_mode = VALID, const Format &format = NCHW,
|
||||
const std::vector<int64_t> &pad = {0, 0, 0, 0}, const RoundMode &round_mode = FLOOR);
|
||||
void set_pad_mode(const PadMode &pad_mode);
|
||||
void set_kernel_size(const std::vector<int64_t> &kernel_size);
|
||||
void set_strides(const std::vector<int64_t> &strides);
|
||||
void set_format(const Format &format);
|
||||
void set_pad(const std::vector<int64_t> &pad);
|
||||
void set_round_mode(const int64_t &round_mode);
|
||||
void set_round_mode(const RoundMode &round_mode);
|
||||
|
||||
std::vector<int64_t> get_kernel_size() const;
|
||||
std::vector<int64_t> get_strides() const;
|
||||
std::string get_padding() const;
|
||||
PadMode get_pad_mode() const;
|
||||
Format get_format() const;
|
||||
std::vector<int64_t> get_pad() const;
|
||||
int64_t get_round_mode() const;
|
||||
RoundMode get_round_mode() const;
|
||||
};
|
||||
|
||||
AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimAvgPoolPtr = std::shared_ptr<AvgPool>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_C_OPS_AVG_POOL_H_
|
||||
#endif // MINDSPORE_CORE_OPS_AVG_POOL_H_
|
|
@ -0,0 +1,140 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "ops/batch_norm.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
void BatchNorm::Init(const bool is_training, const float epsilon, const float momentum, const Format &format) {
|
||||
set_is_training(is_training);
|
||||
set_epsilon(epsilon);
|
||||
set_format(format);
|
||||
set_momentum(momentum);
|
||||
}
|
||||
|
||||
void BatchNorm::set_is_training(const bool is_training) { this->AddAttr(kIsTraining, MakeValue(is_training)); }
|
||||
|
||||
void BatchNorm::set_epsilon(const float epsilon) {
|
||||
CheckAndConvertUtils::CheckInRange<float>(kEpsilon, epsilon, kIncludeBoth, {0.0, 1.0}, this->name());
|
||||
this->AddAttr(kEpsilon, MakeValue(epsilon));
|
||||
}
|
||||
|
||||
void BatchNorm::set_format(const Format &format) {
|
||||
int64_t f = format;
|
||||
this->AddAttr(kFormat, MakeValue(f));
|
||||
}
|
||||
|
||||
void BatchNorm::set_momentum(const float momentun) {
|
||||
CheckAndConvertUtils::CheckInRange<int64_t>(kMomentum, momentun, kIncludeBoth, {0.0, 1.0}, this->name());
|
||||
this->AddAttr(kMomentum, MakeValue(momentun));
|
||||
}
|
||||
|
||||
float BatchNorm::get_momentum() const {
|
||||
auto value_ptr = GetAttr(kMomentum);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
|
||||
bool BatchNorm::get_is_training() const {
|
||||
auto value_ptr = GetAttr(kIsTraining);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
float BatchNorm::get_epsilon() const {
|
||||
auto value_ptr = GetAttr(kEpsilon);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
|
||||
Format BatchNorm::get_format() const {
|
||||
auto value_ptr = GetAttr(kFormat);
|
||||
return Format(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
// Infer shape
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto batch_prim = primitive->cast<PrimBatchNormPtr>();
|
||||
MS_EXCEPTION_IF_NULL(batch_prim);
|
||||
auto prim_name = batch_prim->name();
|
||||
CheckAndConvertUtils::CheckInteger("batch_norm_infer", input_args.size(), kEqual, 5, prim_name);
|
||||
|
||||
auto input_x = CheckAndConvertUtils::ConvertShapePtrToShape("input_x", input_args[0]->BuildShape(), prim_name);
|
||||
if (batch_prim->get_format() == NHWC) {
|
||||
input_x = {input_x[0], input_x[3], input_x[1], input_x[2]};
|
||||
}
|
||||
auto scale = CheckAndConvertUtils::ConvertShapePtrToShape("scale", input_args[1]->BuildShape(), prim_name);
|
||||
auto bias = CheckAndConvertUtils::ConvertShapePtrToShape("bias", input_args[2]->BuildShape(), prim_name);
|
||||
auto mean = CheckAndConvertUtils::ConvertShapePtrToShape("mean", input_args[3]->BuildShape(), prim_name);
|
||||
auto variance = CheckAndConvertUtils::ConvertShapePtrToShape("variance", input_args[4]->BuildShape(), prim_name);
|
||||
|
||||
std::vector<int64_t> input_shape_norm;
|
||||
if (batch_prim->get_format() == NCHW) {
|
||||
input_shape_norm =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name);
|
||||
} else {
|
||||
input_shape_norm.push_back(input_x[0]);
|
||||
input_shape_norm.push_back(input_x[3]);
|
||||
input_shape_norm.push_back(input_x[1]);
|
||||
input_shape_norm.push_back(input_x[2]);
|
||||
}
|
||||
CheckAndConvertUtils::CheckInteger("scale rank", scale.size(), kEqual, 1, prim_name);
|
||||
CheckAndConvertUtils::Check("scale shape", scale, kEqual, "bias shape", bias, prim_name, TypeError);
|
||||
CheckAndConvertUtils::Check("scale shape[0]", scale[0], kEqual, "input_x channel", input_shape_norm[1], prim_name,
|
||||
TypeError);
|
||||
if (!batch_prim->get_is_training()) {
|
||||
CheckAndConvertUtils::CheckInteger("mean rank", mean.size(), kEqual, 1, prim_name);
|
||||
CheckAndConvertUtils::Check("mean shape", mean, kEqual, "variance shape", variance, prim_name, TypeError);
|
||||
CheckAndConvertUtils::Check("mean shape", mean, kEqual, "scale shape", scale, prim_name, TypeError);
|
||||
}
|
||||
|
||||
// Infer type
|
||||
auto input_x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
auto scale_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
auto bias_type = input_args[2]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
|
||||
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name);
|
||||
std::map<std::string, TypePtr> args;
|
||||
args.emplace("scale", input_args[1]->BuildType());
|
||||
args.emplace("bias", input_args[2]->BuildType());
|
||||
CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
|
||||
std::map<std::string, TypePtr> args_moving;
|
||||
args_moving.emplace("scale", input_args[2]->BuildType());
|
||||
args_moving.emplace("bias", input_args[3]->BuildType());
|
||||
CheckAndConvertUtils::CheckTensorTypeSame(args_moving, valid_types, prim_name);
|
||||
|
||||
auto output0 = std::make_shared<abstract::AbstractTensor>(input_x_type, input_x);
|
||||
auto output1 = std::make_shared<abstract::AbstractTensor>(scale_type, scale);
|
||||
auto output2 = std::make_shared<abstract::AbstractTensor>(bias_type, scale);
|
||||
auto output3 = std::make_shared<abstract::AbstractTensor>(input_x_type, scale);
|
||||
if (batch_prim->get_format() == NHWC) {
|
||||
output2 = std::make_shared<abstract::AbstractTensor>(scale_type, scale);
|
||||
output3 = std::make_shared<abstract::AbstractTensor>(bias_type, scale);
|
||||
output1 = std::make_shared<abstract::AbstractTensor>(input_x_type, scale);
|
||||
}
|
||||
AbstractBasePtrList output = {output0, output1, output2, output3, output3};
|
||||
return std::make_shared<abstract::AbstractTuple>(output);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(BatchNorm, prim::kPrimBatchNorm, BatchNormInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameBatchNorm, BatchNorm);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -14,17 +14,18 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_BATCH_NORMAL_H_
|
||||
#define MINDSPORE_CORE_C_OPS_BATCH_NORMAL_H_
|
||||
#ifndef MINDSPORE_CORE_OPS_BATCH_NORMAL_H_
|
||||
#define MINDSPORE_CORE_OPS_BATCH_NORMAL_H_
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "c_ops/op_utils.h"
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameBatchNorm = "BatchNorm";
|
||||
class BatchNorm : public PrimitiveC {
|
||||
public:
|
||||
|
@ -34,19 +35,23 @@ class BatchNorm : public PrimitiveC {
|
|||
}
|
||||
~BatchNorm() = default;
|
||||
MS_DECLARE_PARENT(BatchNorm, PrimitiveC);
|
||||
void Init(bool is_training = false, float epsilon = 1e-5, const Format &format = NCHW);
|
||||
void set_is_training(bool is_training);
|
||||
void set_epsilon(float epsilon);
|
||||
void Init(const bool is_training = false, const float epsilon = 1e-5, const float momentun = 0.1,
|
||||
const Format &format = NCHW);
|
||||
void set_is_training(const bool is_training);
|
||||
void set_epsilon(const float epsilon);
|
||||
void set_format(const Format &format);
|
||||
bool get_is_trainging();
|
||||
float get_epsilon();
|
||||
void set_momentum(const float momentum);
|
||||
bool get_is_training() const;
|
||||
float get_epsilon() const;
|
||||
Format get_format() const;
|
||||
float get_momentum() const;
|
||||
};
|
||||
|
||||
AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimBatchNormPtr = std::shared_ptr<BatchNorm>;
|
||||
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_C_OPS_BatchNorm_H_
|
||||
#endif // MINDSPORE_CORE_OPS_BatchNorm_H_
|
|
@ -0,0 +1,116 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include "ops/batch_norm_fold.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
||||
void BatchNormFold::Init(const float momentum, const float epsilon, const bool is_training, const int64_t freeze_bn) {
|
||||
set_momentum(momentum);
|
||||
set_epsilon(epsilon);
|
||||
set_is_training(is_training);
|
||||
set_freeze_bn(freeze_bn);
|
||||
}
|
||||
|
||||
void BatchNormFold::set_momentum(const float momentum) {
|
||||
CheckAndConvertUtils::CheckInRange<int64_t>(kMomentum, momentum, kIncludeBoth, {0.0, 1.0}, this->name());
|
||||
this->AddAttr(kMomentum, MakeValue(momentum));
|
||||
}
|
||||
|
||||
float BatchNormFold::get_momentum() const {
|
||||
auto value_ptr = GetAttr(kMomentum);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
|
||||
void BatchNormFold::set_epsilon(const float epsilon) {
|
||||
float match_value = 0.0;
|
||||
CheckAndConvertUtils::CheckValue(kEpsilon, epsilon, kGreaterThan, match_value, this->name());
|
||||
this->AddAttr(kEpsilon, MakeValue(epsilon));
|
||||
}
|
||||
|
||||
float BatchNormFold::get_epsilon() const {
|
||||
auto value_ptr = GetAttr(kEpsilon);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
|
||||
void BatchNormFold::set_is_training(const bool is_training) { this->AddAttr(kIsTraining, MakeValue(is_training)); }
|
||||
|
||||
bool BatchNormFold::get_is_training() const {
|
||||
auto value_ptr = GetAttr(kIsTraining);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
void BatchNormFold::set_freeze_bn(const int64_t freeze_bn) { this->AddAttr(kFreezeBn, MakeValue(freeze_bn)); }
|
||||
|
||||
int64_t BatchNormFold::get_freeze_bn() const {
|
||||
auto value_ptr = GetAttr(kFreezeBn);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
AbstractBasePtr BatchNormFoldInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto BatchNormFold_prim = primitive->cast<PrimBatchNormFoldPtr>();
|
||||
MS_EXCEPTION_IF_NULL(BatchNormFold_prim);
|
||||
auto op_name = BatchNormFold_prim->name();
|
||||
auto mean_shape = CheckAndConvertUtils::ConvertShapePtrToShape("mean_shape", input_args[1]->BuildShape(), op_name);
|
||||
auto variance_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("variance_shape", input_args[2]->BuildShape(), op_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name);
|
||||
auto global_step_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("global_step_shape", input_args[3]->BuildShape(), op_name);
|
||||
CheckAndConvertUtils::Check("mean_shape", mean_shape, kEqual, "gamma_shape", variance_shape, op_name);
|
||||
CheckAndConvertUtils::Check("mean_shape[0]", mean_shape[0], kEqual, "input channel", x_shape[1], op_name);
|
||||
CheckAndConvertUtils::CheckInteger("global step shape len", global_step_shape.size(), kEqual, 1, op_name);
|
||||
|
||||
auto mean_type = input_args[1]->BuildType();
|
||||
auto variance_type = input_args[2]->BuildType();
|
||||
auto x_type = input_args[0]->BuildType();
|
||||
auto global_step_type = input_args[3]->BuildType();
|
||||
|
||||
std::map<std::string, TypePtr> args = {{"x", x_type}, {"mean", mean_type}, {"variance", variance_type}};
|
||||
CheckAndConvertUtils::CheckTensorTypeSame(args, {kNumberTypeFloat16, kNumberTypeFloat32}, op_name);
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("gloabal_step", global_step_type, {kNumberTypeInt32}, op_name);
|
||||
|
||||
auto tensor_type0 = x_type->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_type0);
|
||||
auto element0 = tensor_type0->element();
|
||||
|
||||
auto tensor_type1 = mean_type->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_type1);
|
||||
auto element1 = tensor_type1->element();
|
||||
|
||||
auto tensor_type2 = variance_type->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_type2);
|
||||
auto element2 = tensor_type2->element();
|
||||
|
||||
CheckAndConvertUtils::Check("input type", element0->type_id(), kEqual, "mean_type", element1->type_id(), op_name);
|
||||
CheckAndConvertUtils::Check("input type", element0->type_id(), kEqual, "variance_type", element2->type_id(), op_name);
|
||||
|
||||
auto output = std::make_shared<abstract::AbstractTensor>(element0, mean_shape);
|
||||
AbstractBasePtrList output1 = {output, output, output, output};
|
||||
return std::make_shared<abstract::AbstractTuple>(output1);
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(BatchNormFold, prim::kPrimBatchNormFold, BatchNormFoldInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameBatchNormFold, BatchNormFold);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CORE_OPS_BATCH_NORM_FOLD_H_
|
||||
#define MINDSPORE_CORE_OPS_BATCH_NORM_FOLD_H_
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameBatchNormFold = "BatchNormFold";
|
||||
class BatchNormFold : public PrimitiveC {
|
||||
public:
|
||||
BatchNormFold() : PrimitiveC(kNameBatchNormFold) {
|
||||
InitIOName({"x", "mean", "variance", "global_step"}, {"batch_mean", "batch_std", "running_mean", "running_std"});
|
||||
}
|
||||
~BatchNormFold() = default;
|
||||
MS_DECLARE_PARENT(BatchNormFold, PrimitiveC);
|
||||
void Init(const float momentum = 0.9, const float epsilon = 1e-5, const bool is_training = true,
|
||||
const int64_t freeze_bn = 0);
|
||||
void set_momentum(const float momentum);
|
||||
void set_epsilon(const float epsilon);
|
||||
void set_is_training(const bool is_training);
|
||||
void set_freeze_bn(const int64_t freeze_bn);
|
||||
|
||||
float get_momentum() const;
|
||||
float get_epsilon() const;
|
||||
bool get_is_training() const;
|
||||
int64_t get_freeze_bn() const;
|
||||
};
|
||||
|
||||
AbstractBasePtr BatchNormFoldInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimBatchNormFoldPtr = std::shared_ptr<BatchNormFold>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_BATCH_NORM_FOLD_H_
|
|
@ -0,0 +1,81 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/batch_to_space.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
void BatchToSpace::Init(const std::vector<int64_t> &block_size, const std::vector<std::vector<int64_t>> &crops) {
|
||||
this->set_block_size(block_size);
|
||||
this->set_crops(crops);
|
||||
}
|
||||
|
||||
void BatchToSpace::set_block_size(const std::vector<int64_t> &block_size) {
|
||||
this->AddAttr(kBlockSize, MakeValue(block_size));
|
||||
}
|
||||
|
||||
std::vector<int64_t> BatchToSpace::get_block_size() const {
|
||||
auto value_ptr = this->GetAttr(kBlockSize);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
void BatchToSpace::set_crops(const std::vector<std::vector<int64_t>> &crops) {
|
||||
this->AddAttr(kCrops, MakeValue(crops));
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> BatchToSpace::get_crops() const {
|
||||
auto value_ptr = this->GetAttr(kCrops);
|
||||
return GetValue<std::vector<std::vector<int64_t>>>(value_ptr);
|
||||
}
|
||||
|
||||
AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim = primitive->cast<PrimBatchToSpacePtr>();
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), common_valid_types, prim_name);
|
||||
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
|
||||
auto block_size = prim->get_block_size();
|
||||
auto crops = prim->get_crops();
|
||||
auto out_shape = x_shape;
|
||||
for (size_t i = 0; i < 2; ++i) {
|
||||
auto x_block_prod = out_shape[i + 2] * block_size[i];
|
||||
auto crops_sum = crops[i][0] + crops[i][1];
|
||||
CheckAndConvertUtils::Check("x block shape prod", x_block_prod, kGreaterThan, "crops sum", 4, prim_name);
|
||||
out_shape[i + 2] = x_block_prod - crops_sum;
|
||||
}
|
||||
CheckAndConvertUtils::CheckInteger("x_shape[0] % (block_size[0]*block_size[1])",
|
||||
out_shape[0] % (block_size[0] * block_size[1]), kEqual, 0, prim_name);
|
||||
out_shape[0] /= block_size[0] * block_size[1];
|
||||
|
||||
auto ret = input_args[0]->Broaden();
|
||||
ret->set_shape(std::make_shared<abstract::Shape>(out_shape));
|
||||
return ret;
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(BatchToSpace, prim::kPrimBatchToSpace, BatchToSpaceInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameBatchToSpace, BatchToSpace);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CORE_OPS_BATCH_TO_SPACE_H_
|
||||
#define MINDSPORE_CORE_OPS_BATCH_TO_SPACE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameBatchToSpace = "BatchToSpace";
|
||||
class BatchToSpace : public PrimitiveC {
|
||||
public:
|
||||
BatchToSpace() : PrimitiveC(kNameBatchToSpace) {}
|
||||
~BatchToSpace() = default;
|
||||
MS_DECLARE_PARENT(BatchToSpace, PrimitiveC);
|
||||
void Init(const std::vector<int64_t> &block_size, const std::vector<std::vector<int64_t>> &crops);
|
||||
void set_block_size(const std::vector<int64_t> &block_size);
|
||||
void set_crops(const std::vector<std::vector<int64_t>> &crops);
|
||||
std::vector<int64_t> get_block_size() const;
|
||||
std::vector<std::vector<int64_t>> get_crops() const;
|
||||
};
|
||||
|
||||
AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimBatchToSpacePtr = std::shared_ptr<BatchToSpace>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_BATCH_TO_SPACE_H_
|
|
@ -0,0 +1,109 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ops/batch_to_space_nd.h"
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto batch_prim = primitive->cast<PrimBatchToSpaceNDPtr>();
|
||||
MS_EXCEPTION_IF_NULL(batch_prim);
|
||||
auto prim_name = batch_prim->name();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name);
|
||||
auto out_shape = x_shape;
|
||||
int64_t block_shape_prod = 1;
|
||||
int64_t offset = 2;
|
||||
auto block_shape = batch_prim->get_block_shape();
|
||||
auto crops = batch_prim->get_crops();
|
||||
int64_t size = block_shape.size();
|
||||
for (int64_t i = 0; i < size; i++) {
|
||||
block_shape_prod = block_shape_prod * block_shape[i];
|
||||
auto x_block_prod = out_shape[i + offset] * block_shape[i];
|
||||
auto crops_sum = crops[i][0] + crops[i][1];
|
||||
CheckAndConvertUtils::Check("x block shape prod", x_block_prod, kGreaterThan, "crops sum", crops_sum, prim_name);
|
||||
out_shape[i + offset] = x_block_prod - crops_sum;
|
||||
}
|
||||
if (out_shape[0] % block_shape_prod != 0) {
|
||||
MS_EXCEPTION(ValueError) << prim_name << " input_x dimension 0 " << out_shape[0]
|
||||
<< " should be divisible by block_shape_prod " << block_shape_prod;
|
||||
}
|
||||
out_shape[0] = int64_t(floor(out_shape[0] / block_shape_prod));
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto infer_type = input_args[0]->BuildType();
|
||||
return infer_type;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void BatchToSpaceND::set_crops(std::vector<std::vector<int64_t>> crops) {
|
||||
CheckAndConvertUtils::CheckInteger(kCrops, crops.size(), kEqual, 2, this->name());
|
||||
int64_t h = crops.size();
|
||||
int64_t w = crops[0].size();
|
||||
std::vector<int64_t> temp_w = {2, 2};
|
||||
CheckAndConvertUtils::Check(kCrops, {h, w}, kEqual, "paddings_shape(2,2)", temp_w, this->name());
|
||||
for (int64_t i = 0; i < h; i++) {
|
||||
for (int64_t j = 0; j < w; j++) {
|
||||
CheckAndConvertUtils::CheckInteger(kCrops, crops[i][j], kGreaterEqual, 0, this->name());
|
||||
}
|
||||
}
|
||||
this->AddAttr(kCrops, MakeValue(crops));
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> BatchToSpaceND::get_crops() const {
|
||||
auto value_ptr = GetAttr(kCrops);
|
||||
return GetValue<std::vector<std::vector<int64_t>>>(value_ptr);
|
||||
}
|
||||
void BatchToSpaceND::set_block_shape(std::vector<int64_t> block_shape) {
|
||||
CheckAndConvertUtils::CheckInteger(kBlockShape, block_shape.size(), kEqual, 2, this->name());
|
||||
for (int64_t i = 0; i < (int64_t)block_shape.size(); i++) {
|
||||
CheckAndConvertUtils::CheckInteger(kBlockShape, block_shape[i], kGreaterEqual, 1, this->name());
|
||||
}
|
||||
this->AddAttr(kBlockShape, MakeValue(block_shape));
|
||||
}
|
||||
|
||||
std::vector<int64_t> BatchToSpaceND::get_block_shape() const {
|
||||
auto value_ptr = GetAttr(kBlockShape);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
void BatchToSpaceND::Init(std::vector<int64_t> block_shape, std::vector<std::vector<int64_t>> crops) {
|
||||
this->set_crops(crops);
|
||||
this->set_block_shape(block_shape);
|
||||
}
|
||||
AbstractBasePtr BatchToSpaceNDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(BatchToSpaceND, prim::kPrimBatchToSpaceND, BatchToSpaceNDInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameBatchToSpaceND, BatchToSpaceND);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_BATCH_TO_SPACE_ND_H_
|
||||
#define MINDSPORE_CORE_OPS_BATCH_TO_SPACE_ND_H_
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameBatchToSpaceND = "BatchToSpaceND";
|
||||
class BatchToSpaceND : public PrimitiveC {
|
||||
public:
|
||||
BatchToSpaceND() : PrimitiveC(kNameBatchToSpaceND) {}
|
||||
~BatchToSpaceND() = default;
|
||||
MS_DECLARE_PARENT(BatchToSpaceND, PrimitiveC);
|
||||
void Init(std::vector<int64_t> block_shape, std::vector<std::vector<int64_t>> crops);
|
||||
void set_crops(std::vector<std::vector<int64_t>> crops);
|
||||
void set_block_shape(std::vector<int64_t> block_shape);
|
||||
std::vector<int64_t> get_block_shape() const;
|
||||
std::vector<std::vector<int64_t>> get_crops() const;
|
||||
};
|
||||
AbstractBasePtr BatchToSpaceNDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimBatchToSpaceNDPtr = std::shared_ptr<BatchToSpaceND>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_BATCH_TO_SPACE_ND_H_
|
|
@ -14,12 +14,15 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/bias_add.h"
|
||||
#include "ops/bias_add.h"
|
||||
#include <memory>
|
||||
#include "c_ops/op_utils.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
// Add
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
void BiasAdd::set_format(const Format &format) {
|
||||
int64_t f = format;
|
||||
this->AddAttr(kFormat, MakeValue(f));
|
||||
|
@ -29,5 +32,7 @@ Format BiasAdd::get_format() const {
|
|||
return Format(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
void BiasAdd::Init(const Format &format) { this->set_format(format); }
|
||||
// Add
|
||||
REGISTER_PRIMITIVE_C(kNameBiasAdd, BiasAdd);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -14,17 +14,20 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_BIASADD_H_
|
||||
#define MINDSPORE_CORE_C_OPS_BIASADD_H_
|
||||
#ifndef MINDSPORE_CORE_OPS_BIAS_ADD_H_
|
||||
#define MINDSPORE_CORE_OPS_BIAS_ADD_H_
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
// Add
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameBiasAdd = "BiasAdd";
|
||||
class BiasAdd : public PrimitiveC {
|
||||
public:
|
||||
|
@ -35,6 +38,7 @@ class BiasAdd : public PrimitiveC {
|
|||
void set_format(const Format &format);
|
||||
Format get_format() const;
|
||||
};
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_C_OPS_BIASADD_H_
|
||||
#endif // MINDSPORE_CORE_OPS_BIAS_ADD_H_
|
|
@ -0,0 +1,94 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
#include "ops/binary_cross_entropy.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr BinaryCrossEntroyInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto binary_cross_entropy_prim = primitive->cast<PrimBinaryCrossEntropyPtr>();
|
||||
MS_EXCEPTION_IF_NULL(binary_cross_entropy_prim);
|
||||
auto prim_name = binary_cross_entropy_prim->name();
|
||||
CheckAndConvertUtils::CheckInRange("binary_cross_entropy_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name);
|
||||
auto weight_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("weight_shape", input_args[2]->BuildShape(), prim_name);
|
||||
CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name);
|
||||
std::vector<int64_t> infer_shape;
|
||||
if (weight_shape.size() < 1) {
|
||||
CheckAndConvertUtils::Check("x shape", y_shape, kEqual, "weight shape", weight_shape, prim_name);
|
||||
}
|
||||
if (binary_cross_entropy_prim->get_reduction() != REDUCTION_SUM &&
|
||||
binary_cross_entropy_prim->get_reduction() != MEAN) {
|
||||
infer_shape = {x_shape.begin(), infer_shape.end()};
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(infer_shape);
|
||||
}
|
||||
|
||||
TypePtr BinaryCrossEntroyInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
CheckAndConvertUtils::CheckInteger("binary_cross_entropy_infer", input_args.size(), kEqual, 3, prim->name());
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x_shape", input_args[0]->BuildType());
|
||||
types.emplace("y_shape", input_args[1]->BuildType());
|
||||
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
if (input_args[3]->BuildType() != nullptr) {
|
||||
types.emplace("x_shape", input_args[0]->BuildType());
|
||||
types.emplace("weight_shape", input_args[2]->BuildType());
|
||||
infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
}
|
||||
return TypeIdToType(infer_type);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void BinaryCrossEntropy::set_reduction(const Reduction &reduction) {
|
||||
int64_t swi = reduction;
|
||||
this->AddAttr(kReduction, MakeValue(swi));
|
||||
}
|
||||
|
||||
Reduction BinaryCrossEntropy::get_reduction() const {
|
||||
auto value_ptr = GetAttr(kReduction);
|
||||
return Reduction(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
void BinaryCrossEntropy::Init(const Reduction &reduction) { this->set_reduction(reduction); }
|
||||
|
||||
AbstractBasePtr BinaryCrossEntropyInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(BinaryCrossEntroyInferType(primitive, input_args),
|
||||
BinaryCrossEntroyInferShape(primitive, input_args)->shape());
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(BinaryCrossEntropy, prim::kPrimBinaryCrossEntropy, BinaryCrossEntropyInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameBinaryCrossEntropy, BinaryCrossEntropy);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -14,25 +14,32 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_BINARY_CROSS_ENTROPY_GRAD_H_
|
||||
#define MINDSPORE_CORE_C_OPS_BINARY_CROSS_ENTROPY_GRAD_H_
|
||||
#ifndef MINDSPORE_CORE_OPS_BINARY_CROSS_ENTROPY_H_
|
||||
#define MINDSPORE_CORE_OPS_BINARY_CROSS_ENTROPY_H_
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "c_ops/op_utils.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameBinaryCrossEntropy = "BinaryCrossEntropy";
|
||||
class BinaryCrossEntropy : public PrimitiveC {
|
||||
public:
|
||||
BinaryCrossEntropy() : PrimitiveC(kNameBinaryCrossEntropy) {}
|
||||
~BinaryCrossEntropy() = default;
|
||||
MS_DECLARE_PARENT(BinaryCrossEntropy, PrimitiveC);
|
||||
void Init(const std::string &reduction = "mean");
|
||||
void set_reduction(const std::string &reduction);
|
||||
std::string get_reduction() const;
|
||||
void Init(const Reduction &reduction = MEAN);
|
||||
void set_reduction(const Reduction &reduction);
|
||||
Reduction get_reduction() const;
|
||||
};
|
||||
AbstractBasePtr BinaryCrossEntropyGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimBinaryCrossEntropyPtr = std::shared_ptr<BinaryCrossEntropy>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_C_OPS_BINARY_CROSS_ENTROPY_GRAD_H_
|
||||
#endif // MINDSPORE_CORE_OPS_BINARY_CROSS_ENTROPY_H_
|
|
@ -14,13 +14,14 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/black_box.h"
|
||||
#include "c_ops/op_utils.h"
|
||||
#include "ops/black_box.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
void BlackBox::Init(const std::string &id, int64_t size, const std::vector<int64_t> &address) {
|
||||
namespace ops {
|
||||
void BlackBox::Init(const std::string &id, const int64_t size, const std::vector<int64_t> &address) {
|
||||
this->set_id(id);
|
||||
this->set_size(size);
|
||||
this->set_address(address);
|
||||
|
@ -33,7 +34,7 @@ std::string BlackBox::get_id() const {
|
|||
return GetValue<std::string>(value_ptr);
|
||||
}
|
||||
|
||||
void BlackBox::set_size(int64_t size) { this->AddAttr(kSize, MakeValue(size)); }
|
||||
void BlackBox::set_size(const int64_t size) { this->AddAttr(kSize, MakeValue(size)); }
|
||||
|
||||
int64_t BlackBox::get_size() const {
|
||||
auto value_ptr = this->GetAttr(kSize);
|
||||
|
@ -47,4 +48,5 @@ std::vector<int64_t> BlackBox::get_address() const {
|
|||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameBlackBox, BlackBox);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -14,25 +14,26 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_BLACKBOX_H_
|
||||
#define MINDSPORE_CORE_C_OPS_BLACKBOX_H_
|
||||
#ifndef MINDSPORE_CORE_OPS_BLACK_BOX_H_
|
||||
#define MINDSPORE_CORE_OPS_BLACK_BOX_H_
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameBlackBox = "BlackBox";
|
||||
class BlackBox : public PrimitiveC {
|
||||
public:
|
||||
BlackBox() : PrimitiveC(kNameBlackBox) {}
|
||||
~BlackBox() = default;
|
||||
MS_DECLARE_PARENT(BlackBox, PrimitiveC);
|
||||
void Init(const std::string &id, int64_t size, const std::vector<int64_t> &address);
|
||||
void Init(const std::string &id, const int64_t size, const std::vector<int64_t> &address);
|
||||
void set_id(const std::string &id);
|
||||
void set_size(int64_t size);
|
||||
void set_size(const int64_t size);
|
||||
void set_address(const std::vector<int64_t> &address);
|
||||
std::string get_id() const;
|
||||
int64_t get_size() const;
|
||||
|
@ -40,6 +41,7 @@ class BlackBox : public PrimitiveC {
|
|||
};
|
||||
|
||||
using PrimBlackBoxPtr = std::shared_ptr<BlackBox>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_C_OPS_BLACKBOX_H_
|
||||
#endif // MINDSPORE_CORE_OPS_BLACK_BOX_H_
|
|
@ -0,0 +1,70 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/broadcast.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
void Broadcast::Init(const int64_t root_rank, const std::string &group) {
|
||||
this->set_root_rank(root_rank);
|
||||
this->set_group(group);
|
||||
}
|
||||
void Broadcast::set_root_rank(const int64_t root_rank) { this->AddAttr(kKeepProb, MakeValue(root_rank)); }
|
||||
|
||||
void Broadcast::set_group(const std::string &group) {
|
||||
CheckAndConvertUtils::CheckString(kGroup, group, {"hccl_world_group", "hccl_world_group"}, this->name());
|
||||
this->AddAttr(kGroup, MakeValue(group));
|
||||
}
|
||||
int64_t Broadcast::get_root_rank() const {
|
||||
auto value_ptr = this->GetAttr(kRootRank);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
|
||||
std::string Broadcast::get_group() const {
|
||||
auto value_ptr = this->GetAttr(kGroup);
|
||||
return GetValue<std::string>(value_ptr);
|
||||
}
|
||||
AbstractBasePtr BroadcastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto broadcast_prim = primitive->cast<PrimBroadcast>();
|
||||
MS_EXCEPTION_IF_NULL(broadcast_prim);
|
||||
auto prim_name = broadcast_prim->name();
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
// infer shape
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
// infer type
|
||||
auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
std::vector<TypePtr> output_types;
|
||||
const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32};
|
||||
for (size_t i = 0; i < input_args.size(); i++) {
|
||||
auto out_type = input_args[i]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
output_types.push_back(out_type);
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("index_type", out_type, valid_types, prim_name);
|
||||
}
|
||||
return std::make_shared<abstract::AbstractTensor>(x_type, in_shape);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Broadcast, prim::kPrimBroadcast, BroadcastInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameBroadcast, Broadcast);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -14,27 +14,34 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_BROADCAST_H_
|
||||
#define MINDSPORE_CORE_C_OPS_BROADCAST_H_
|
||||
#ifndef MINDSPORE_CORE_OPS_BROADCAST_H_
|
||||
#define MINDSPORE_CORE_OPS_BROADCAST_H_
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameBroadcast = "Broadcast";
|
||||
class Broadcast : public PrimitiveC {
|
||||
public:
|
||||
Broadcast() : PrimitiveC(kNameBroadcast) {}
|
||||
~Broadcast() = default;
|
||||
MS_DECLARE_PARENT(Broadcast, PrimitiveC);
|
||||
void Init(int64_t root_rank, const std::string &group = "hccl_world_group");
|
||||
void set_root_rank(int64_t root_rank);
|
||||
void Init(const int64_t root_rank, const std::string &group = "hccl_world_group");
|
||||
void set_root_rank(const int64_t root_rank);
|
||||
void set_group(const std::string &group);
|
||||
int64_t get_root_rank();
|
||||
int64_t get_root_rank() const;
|
||||
std::string get_group() const;
|
||||
};
|
||||
AbstractBasePtr BroadcastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimBroadcast = std::shared_ptr<Broadcast>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_C_OPS_BROADCAST_H_
|
||||
#endif // MINDSPORE_CORE_OPS_BROADCAST_H_
|
|
@ -0,0 +1,88 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <set>
|
||||
#include "ops/broadcast_to.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr BroadcastToInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto broad_cast_to = primitive->cast<PrimBroadcastToPtr>();
|
||||
MS_EXCEPTION_IF_NULL(broad_cast_to);
|
||||
auto prim_name = broad_cast_to->name();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto input_x = broad_cast_to->get_shape();
|
||||
int64_t outer_dim_offset = input_x.size() - x_shape.size();
|
||||
CheckAndConvertUtils::Check("x shape", x_shape, kLessEqual, "input_x", input_x, prim_name);
|
||||
bool flag = true;
|
||||
if (input_x.end() == find(input_x.begin(), input_x.end(), -1)) {
|
||||
flag = false;
|
||||
} else {
|
||||
flag = true;
|
||||
}
|
||||
if (flag == true) {
|
||||
for (int64_t i = 0; i < (int64_t)input_x.size(); i++) {
|
||||
if (input_x[i] == -1) {
|
||||
if (i < outer_dim_offset) {
|
||||
MS_EXCEPTION(ValueError) << " -1 in init shape is in an incompatible "
|
||||
"location with given input tensor, -1 index in init shape: "
|
||||
<< i << " but -1 can only be in index" << x_shape.size()
|
||||
<< "onwards for this input.";
|
||||
}
|
||||
input_x[i] = x_shape[i - outer_dim_offset];
|
||||
}
|
||||
}
|
||||
}
|
||||
std::reverse(input_x.begin(), input_x.end());
|
||||
return std::make_shared<abstract::Shape>(input_x);
|
||||
}
|
||||
|
||||
TypePtr BroadcastToInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto x_dtype = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
std::set<TypePtr> template_types = {TypeIdToType(kObjectTypeTensorType)};
|
||||
CheckAndConvertUtils::CheckSubClass("x_dtype", x_dtype, template_types, prim->name());
|
||||
auto infer_dtype = input_args[0]->BuildType()->type_id();
|
||||
return TypeIdToType(infer_dtype);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void BroadcastTo::Init(const std::vector<int64_t> &shape) { set_shape(shape); }
|
||||
|
||||
void BroadcastTo::set_shape(const std::vector<int64_t> &shape) {
|
||||
CheckAndConvertUtils::CheckInteger(kShapeSize, shape.size(), kGreaterThan, 0, name());
|
||||
AddAttr(kShape, MakeValue(shape));
|
||||
}
|
||||
|
||||
std::vector<int64_t> BroadcastTo::get_shape() const {
|
||||
auto value_ptr = GetAttr(kShape);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
AbstractBasePtr BroadcastToInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(BroadcastToInferType(primitive, input_args),
|
||||
BroadcastToInferShape(primitive, input_args)->shape());
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(BroadcastTo, prim::kPrimBroadcastTo, BroadcastToInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameBroadcastTo, BroadcastTo);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -14,18 +14,19 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_BROADCAST_H_
|
||||
#define MINDSPORE_CORE_C_OPS_BROADCAST_H_
|
||||
#ifndef MINDSPORE_CORE_OPS_BROADCAST_TO_H_
|
||||
#define MINDSPORE_CORE_OPS_BROADCAST_TO_H_
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "c_ops/op_utils.h"
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameBroadcastTo = "BroadcastTo";
|
||||
class BroadcastTo : public PrimitiveC {
|
||||
public:
|
||||
|
@ -41,6 +42,7 @@ AbstractBasePtr BroadcastToInfer(const abstract::AnalysisEnginePtr &, const Prim
|
|||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimBroadcastToPtr = std::shared_ptr<BroadcastTo>;
|
||||
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_C_OPS_BROADCAST_H_
|
||||
#endif // MINDSPORE_CORE_OPS_BROADCAST_TO_H_
|
|
@ -14,8 +14,10 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/cast.h"
|
||||
#include "ops/cast.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
REGISTER_PRIMITIVE_C(kNameCast, Cast);
|
||||
}
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -14,17 +14,18 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_CAST_H_
|
||||
#define MINDSPORE_CORE_C_OPS_CAST_H_
|
||||
#ifndef MINDSPORE_CORE_OPS_CAST_H_
|
||||
#define MINDSPORE_CORE_OPS_CAST_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "c_ops/op_utils.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameCast = "Cast";
|
||||
class Cast : public PrimitiveC {
|
||||
public:
|
||||
|
@ -35,5 +36,6 @@ class Cast : public PrimitiveC {
|
|||
AbstractBasePtr CastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimCast = std::shared_ptr<Cast>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_C_OPS_CAST_H_
|
||||
#endif // MINDSPORE_CORE_OPS_CAST_H_
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <set>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "ops/ceil.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
||||
AbstractBasePtr CeilInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Ceil");
|
||||
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
|
||||
auto infer_type = input_args[0]->BuildType();
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("x type", infer_type, valid_types, primitive->name());
|
||||
MS_EXCEPTION_IF_NULL(infer_type);
|
||||
auto tensor_type = infer_type->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||
auto data_type = tensor_type->element();
|
||||
MS_EXCEPTION_IF_NULL(data_type);
|
||||
return std::make_shared<abstract::AbstractTensor>(data_type, x_shape);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Ceil, prim::kPrimCeil, CeilInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameCeil, Ceil);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -14,26 +14,29 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_CEIL_H_
|
||||
#define MINDSPORE_CORE_C_OPS_CEIL_H_
|
||||
#ifndef MINDSPORE_CORE_OPS_CEIL_H_
|
||||
#define MINDSPORE_CORE_OPS_CEIL_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "c_ops/op_utils.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameCeil = "Ceil";
|
||||
class Ceil : public PrimitiveC {
|
||||
public:
|
||||
Ceil() : PrimitiveC(kNameCeil) { InitIOName({"x"}, {"y"}); }
|
||||
~Ceil() = default;
|
||||
MS_DECLARE_PARENT(Ceil, PrimitiveC);
|
||||
void init() {}
|
||||
};
|
||||
|
||||
AbstractBasePtr CeilInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimCeil = std::shared_ptr<Ceil>;
|
||||
using PrimCeilPtr = std::shared_ptr<Ceil>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_C_OPS_CEIL_H_
|
||||
#endif // MINDSPORE_CORE_OPS_CEIL_H_
|
|
@ -14,12 +14,13 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/clip.h"
|
||||
#include "ops/clip.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "c_ops/op_utils.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
void Clip::Init(const float max, const float min) {
|
||||
this->set_max(max);
|
||||
this->set_min(min);
|
||||
|
@ -39,4 +40,5 @@ float Clip::get_min() const {
|
|||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameClip, Clip);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -13,15 +13,16 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CORE_C_OPS_CLIP_H_
|
||||
#define MINDSPORE_CORE_C_OPS_CLIP_H_
|
||||
#ifndef MINDSPORE_CORE_OPS_CLIP_H_
|
||||
#define MINDSPORE_CORE_OPS_CLIP_H_
|
||||
#include <memory>
|
||||
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameClip = "Clip";
|
||||
class Clip : public PrimitiveC {
|
||||
public:
|
||||
|
@ -36,6 +37,7 @@ class Clip : public PrimitiveC {
|
|||
};
|
||||
|
||||
using PrimClipPtr = std::shared_ptr<Clip>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_C_OPS_CLIP_H_
|
||||
#endif // MINDSPORE_CORE_OPS_CLIP_H_
|
|
@ -0,0 +1,85 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include "ops/concat.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
||||
void Concat::Init(const int64_t axis) { this->set_axis(axis); }
|
||||
int64_t Concat::get_axis() const {
|
||||
auto value_ptr = this->GetAttr(kAxis);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
void Concat::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis)); }
|
||||
|
||||
AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim = primitive->cast<PrimConcatPtr>();
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto input_tuple = input_args[0]->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_tuple);
|
||||
auto elements = input_tuple->elements();
|
||||
CheckAndConvertUtils::CheckInteger("concat element num", elements.size(), kGreaterEqual, 1, prim_name);
|
||||
auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(element0);
|
||||
auto element0_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("element0 shape", element0->BuildShape(), prim_name);
|
||||
auto element0_rank = SizeToLong(element0_shape.size());
|
||||
auto axis = prim->get_axis();
|
||||
CheckAndConvertUtils::CheckInRange<int64_t>("Concat axis", axis, kIncludeBoth, {-element0_rank - 1, element0_rank},
|
||||
prim_name);
|
||||
axis = axis < 0 ? axis + element0_rank : axis;
|
||||
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("element0", element0->BuildType());
|
||||
int64_t all_shp = element0_shape[axis];
|
||||
for (size_t i = 1; i < elements.size(); ++i) {
|
||||
std::string elementi = "element" + std::to_string(i);
|
||||
auto elementi_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape(elementi + " shape", elements[i]->BuildShape(), prim_name);
|
||||
CheckAndConvertUtils::CheckInteger(elementi + " shape rank", elementi_shape.size(), kEqual, element0_shape.size(),
|
||||
prim_name);
|
||||
for (int64_t j = 0; j < element0_rank; ++j) {
|
||||
if (j != axis && elementi_shape[j] != element0_shape[j]) {
|
||||
MS_LOG(EXCEPTION) << "element " << i << " shape in input can not concat with first element.";
|
||||
}
|
||||
}
|
||||
all_shp = all_shp == -1 || elementi_shape[axis] == -1 ? -1 : all_shp + elementi_shape[axis];
|
||||
types.emplace(elementi, elements[i]->BuildType());
|
||||
}
|
||||
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, all_types, prim_name);
|
||||
auto ret_shape = element0_shape;
|
||||
ret_shape[axis] = all_shp;
|
||||
|
||||
return std::make_shared<abstract::AbstractTensor>(TypeIdToType(infer_type),
|
||||
std::make_shared<abstract::Shape>(ret_shape));
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Concat, prim::kPrimConcat, ConcatInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameConcat, Concat);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -14,29 +14,31 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_CONCAT_H_
|
||||
#define MINDSPORE_CORE_C_OPS_CONCAT_H_
|
||||
#ifndef MINDSPORE_CORE_OPS_CONCAT_H_
|
||||
#define MINDSPORE_CORE_OPS_CONCAT_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "c_ops/op_utils.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameConcat = "Concat";
|
||||
class Concat : public PrimitiveC {
|
||||
public:
|
||||
Concat() : PrimitiveC(kNameConcat) {}
|
||||
~Concat() = default;
|
||||
MS_DECLARE_PARENT(Concat, PrimitiveC);
|
||||
void Init(int64_t axis = 0);
|
||||
void set_axis(int64_t axis);
|
||||
void Init(const int64_t axis = 0);
|
||||
void set_axis(const int64_t axis);
|
||||
int64_t get_axis() const;
|
||||
};
|
||||
AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimConcatPtr = std::shared_ptr<Concat>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_C_OPS_CONCAT_H_
|
||||
#endif // MINDSPORE_CORE_OPS_CONCAT_H_
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/constant.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto x = input_args[0]->BuildShape();
|
||||
auto shape_element = x->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_element);
|
||||
return shape_element;
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim->name());
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
return TypeIdToType(infer_type);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr ConstantInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Constant, prim::kPrimConstant, ConstantInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameConstant, Constant);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_CONSTANT_H_
|
||||
#define MINDSPORE_CORE_OPS_CONSTANT_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameConstant = "Constant";
|
||||
class Constant : public PrimitiveC {
|
||||
public:
|
||||
Constant() : PrimitiveC(kNameConstant) {}
|
||||
~Constant() = default;
|
||||
MS_DECLARE_PARENT(Constant, PrimitiveC);
|
||||
void Init() {}
|
||||
};
|
||||
|
||||
AbstractBasePtr ConstantInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimConstantPtr = std::shared_ptr<Constant>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_CONSTANT_H_
|
|
@ -14,12 +14,30 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/constant_of_shape.h"
|
||||
#include "c_ops/op_utils.h"
|
||||
#include "ops/constant_of_shape.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
CheckAndConvertUtils::CheckInteger("input args size", input_args.size(), kEqual, 1, "ConstantOfShape");
|
||||
auto input_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "ConstantOfShape");
|
||||
return std::make_shared<abstract::Shape>(input_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto constant_prim = primitive->cast<PrimConstantOfShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(constant_prim);
|
||||
auto data_type = TypeId(constant_prim->get_data_type());
|
||||
return TypeIdToType(data_type);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void ConstantOfShape::Init(int64_t data_type, const std::vector<float> &value) {
|
||||
this->set_data_type(data_type);
|
||||
this->set_value(value);
|
||||
|
@ -38,5 +56,12 @@ std::vector<float> ConstantOfShape::get_value() const {
|
|||
auto value_ptr = this->GetAttr(kValue);
|
||||
return GetValue<std::vector<float>>(value_ptr);
|
||||
}
|
||||
AbstractBasePtr ConstantOfShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ConstantOfShape, prim::kPrimConstantOfShape, ConstantOfShapeInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameConstantOfShape, ConstantOfShape);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -14,15 +14,16 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_CONSTANTOFSHAPE_H_
|
||||
#define MINDSPORE_CORE_C_OPS_CONSTANTOFSHAPE_H_
|
||||
#ifndef MINDSPORE_CORE_OPS_CONSTANT_OF_SHAPE_H_
|
||||
#define MINDSPORE_CORE_OPS_CONSTANT_OF_SHAPE_H_
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameConstantOfShape = "ConstantOfShape";
|
||||
class ConstantOfShape : public PrimitiveC {
|
||||
public:
|
||||
|
@ -36,7 +37,10 @@ class ConstantOfShape : public PrimitiveC {
|
|||
std::vector<float> get_value() const;
|
||||
};
|
||||
|
||||
AbstractBasePtr ConstantOfShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimConstantOfShapePtr = std::shared_ptr<ConstantOfShape>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_C_OPS_CONSTANTOFSHAPE_H_
|
||||
#endif // MINDSPORE_CORE_OPS_CONSTANT_OF_SHAPE_H_
|
|
@ -14,19 +14,21 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/control_depend.h"
|
||||
#include "ops/control_depend.h"
|
||||
|
||||
namespace mindspore {
|
||||
void ControlDepend::Init(int64_t depend_mode) { this->set_depend_mode(depend_mode); }
|
||||
namespace ops {
|
||||
void ControlDepend::Init(const int64_t depend_mode) { this->set_depend_mode(depend_mode); }
|
||||
|
||||
void ControlDepend::set_depend_mode(int64_t depend_mode) {
|
||||
CheckAndConvertUtils::CheckInRange(kDependMode, depend_mode, kIncludeBoth, {0, 1}, name());
|
||||
void ControlDepend::set_depend_mode(const int64_t depend_mode) {
|
||||
CheckAndConvertUtils::CheckInRange<int64_t>(kDependMode, depend_mode, kIncludeBoth, {0, 1}, name());
|
||||
AddAttr(kDependMode, MakeValue(depend_mode));
|
||||
}
|
||||
|
||||
int64_t ControlDepend::get_depend_mode() {
|
||||
int64_t ControlDepend::get_depend_mode() const {
|
||||
auto value_ptr = GetAttr(kDependMode);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameControlDepend, ControlDepend);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -14,29 +14,29 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_CONTROL_DEPEND_H_
|
||||
#define MINDSPORE_CORE_C_OPS_CONTROL_DEPEND_H_
|
||||
#ifndef MINDSPORE_CORE_OPS_CONTROL_DEPEND_H_
|
||||
#define MINDSPORE_CORE_OPS_CONTROL_DEPEND_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "c_ops/op_utils.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameControlDepend = "ControlDepend";
|
||||
class ControlDepend : public PrimitiveC {
|
||||
public:
|
||||
ControlDepend() : PrimitiveC(kNameControlDepend) {}
|
||||
~ControlDepend() = default;
|
||||
MS_DECLARE_PARENT(ControlDepend, PrimitiveC);
|
||||
void Init(int64_t depend_mode);
|
||||
void set_depend_mode(int64_t depend_mode);
|
||||
int64_t get_depend_mode();
|
||||
void Init(const int64_t depend_mode);
|
||||
void set_depend_mode(const int64_t depend_mode = 0);
|
||||
int64_t get_depend_mode() const;
|
||||
};
|
||||
AbstractBasePtr ControlDependInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimControlDepend = std::shared_ptr<ControlDepend>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_C_OPS_CONTROl_DEPEND_H_
|
||||
#endif // MINDSPORE_CORE_OPS_CONTROl_DEPEND_H_
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -14,7 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/conv2d.h"
|
||||
#include "ops/conv2d.h"
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
|
@ -23,102 +23,29 @@
|
|||
#include "ir/dtype/tensor_type.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "ops/control_depend.h"
|
||||
|
||||
namespace mindspore {
|
||||
Conv2D::Conv2D() : PrimitiveC(kNameConv2D) { InitIOName({"x", "w"}, {"output"}); }
|
||||
|
||||
void Conv2D::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode,
|
||||
const std::string &pad_mode, const std::vector<int64_t> &pad, const std::vector<int64_t> &stride,
|
||||
const std::vector<int64_t> &dilation, int64_t group) {
|
||||
auto prim_name = this->name();
|
||||
this->AddAttr("data_format", MakeValue("NCHW"));
|
||||
this->AddAttr("offset_a", MakeValue(static_cast<int64_t>(0)));
|
||||
this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name));
|
||||
this->set_stride(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name(), true, true));
|
||||
this->set_dilation(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name(), true, true));
|
||||
this->set_pad_mode(CheckAndConvertUtils::CheckString(kPadMode, pad_mode, {"valid", "same", "pad"}, prim_name));
|
||||
CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, prim_name);
|
||||
if (pad_mode == "pad") {
|
||||
for (auto item : pad) {
|
||||
CheckAndConvertUtils::Check("pad_item", item, kGreaterEqual, "zeros_list", 0, prim_name);
|
||||
}
|
||||
} else {
|
||||
CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, prim_name);
|
||||
}
|
||||
this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name(), true, true));
|
||||
this->set_mode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 1, prim_name));
|
||||
this->set_out_channel(CheckAndConvertUtils::CheckInteger("out_channel", out_channel, kGreaterThan, 0, prim_name));
|
||||
this->set_group(CheckAndConvertUtils::CheckInteger("group", group, kGreaterThan, 0, prim_name));
|
||||
}
|
||||
|
||||
std::vector<int64_t> Conv2D::get_kernel_size() const {
|
||||
auto value_ptr = GetAttr(kKernelSize);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
std::vector<int64_t> Conv2D::get_stride() const {
|
||||
auto value_ptr = GetAttr(kStride);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
std::vector<int64_t> Conv2D::get_dilation() const {
|
||||
auto value_ptr = GetAttr(kDilation);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
std::string Conv2D::get_pad_mode() const {
|
||||
auto value_ptr = this->GetAttr(kPadMode);
|
||||
return GetValue<string>(value_ptr);
|
||||
}
|
||||
|
||||
std::vector<int64_t> Conv2D::get_pad() const {
|
||||
auto value_ptr = this->GetAttr(kPad);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
std::vector<int64_t> Conv2D::get_pad_list() const {
|
||||
auto value_ptr = this->GetAttr(kPadList);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
int64_t Conv2D::get_mode() const {
|
||||
auto value_ptr = this->GetAttr(kMode);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
int64_t Conv2D::get_group() const {
|
||||
auto value_ptr = this->GetAttr(kGroup);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
int64_t Conv2D::get_output_channel() const {
|
||||
auto value_ptr = this->GetAttr(kOutputChannel);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
void Conv2D::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
||||
this->AddAttr(kKernelSize, MakeValue(kernel_size));
|
||||
}
|
||||
void Conv2D::set_stride(const std::vector<int64_t> &stride) { this->AddAttr(kStride, MakeValue(stride)); }
|
||||
void Conv2D::set_dilation(const std::vector<int64_t> &dilation) { this->AddAttr(kDilation, MakeValue(dilation)); }
|
||||
void Conv2D::set_pad_mode(const std::string &pad_mode) { this->AddAttr(kPadMode, MakeValue(pad_mode)); }
|
||||
void Conv2D::set_pad(const std::vector<int64_t> &pad) { this->AddAttr(kPad, MakeValue(pad)); }
|
||||
void Conv2D::set_mode(int64_t mode) { this->AddAttr(kMode, MakeValue(mode)); }
|
||||
void Conv2D::set_group(int64_t group) { this->AddAttr(kGroup, MakeValue(group)); }
|
||||
void Conv2D::set_out_channel(int64_t output_channel) { this->AddAttr(kOutputChannel, MakeValue(output_channel)); }
|
||||
void Conv2D::set_pad_list(const std::vector<int64_t> &pad_list) { this->AddAttr(kPadList, MakeValue(pad_list)); }
|
||||
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto conv_prim = primitive->cast<PrimConv2dPtr>();
|
||||
MS_EXCEPTION_IF_NULL(conv_prim);
|
||||
auto prim_name = conv_prim->name();
|
||||
CheckAndConvertUtils::CheckInRange("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
|
||||
CheckAndConvertUtils::CheckInRange<size_t>("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name);
|
||||
if (conv_prim->get_format() == NHWC) {
|
||||
x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};
|
||||
w_shape = {w_shape[0], w_shape[3], w_shape[1], w_shape[2]};
|
||||
}
|
||||
|
||||
CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
|
||||
CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->get_group(), kEqual, "w_shape[1]",
|
||||
w_shape[1], conv_prim->name());
|
||||
auto out_channel = conv_prim->get_output_channel();
|
||||
auto out_channel = conv_prim->get_out_channel();
|
||||
CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name());
|
||||
std::vector<int64_t> temp_w;
|
||||
std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w));
|
||||
|
@ -137,10 +64,10 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
|
|||
int64_t w_out = -1;
|
||||
std::vector<int64_t> pad_list(4, 0);
|
||||
auto pad_mode = conv_prim->get_pad_mode();
|
||||
if (pad_mode == "valid") {
|
||||
if (pad_mode == VALID) {
|
||||
h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h);
|
||||
w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w);
|
||||
} else if (pad_mode == "same") {
|
||||
} else if (pad_mode == SAME) {
|
||||
h_out = ceil(x_shape[2] / stride_h);
|
||||
w_out = ceil(x_shape[3] / stride_w);
|
||||
|
||||
|
@ -153,7 +80,7 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
|
|||
auto pad_left = floor(pad_needed_w / 2);
|
||||
pad_list.emplace_back(pad_left);
|
||||
pad_list.emplace_back(pad_needed_h - pad_left);
|
||||
} else if (pad_mode == "pad") {
|
||||
} else if (pad_mode == PAD) {
|
||||
std::copy(conv_prim->get_pad().begin(), conv_prim->get_pad().end(), std::back_inserter(pad_list));
|
||||
auto pad_top = conv_prim->get_pad()[0];
|
||||
auto pad_bottom = conv_prim->get_pad()[1];
|
||||
|
@ -165,13 +92,17 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
|
|||
h_out = floor(h_out);
|
||||
w_out = floor(w_out);
|
||||
}
|
||||
conv_prim->set_pad_list(pad_list);
|
||||
conv_prim->set_pad(pad_list);
|
||||
std::vector<int64_t> out_shape = {x_shape[0], out_channel, h_out, w_out};
|
||||
if (conv_prim->get_format() == NHWC) {
|
||||
out_shape = {x_shape[0], h_out, w_out, out_channel};
|
||||
}
|
||||
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
CheckAndConvertUtils::CheckInRange("", input_args.size(), kIncludeBoth, {2, 3}, prim->name());
|
||||
CheckAndConvertUtils::CheckInRange<size_t>("", input_args.size(), kIncludeBoth, {2, 3}, prim->name());
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -186,12 +117,121 @@ TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBase
|
|||
}
|
||||
return TypeIdToType(infer_type);
|
||||
}
|
||||
} // namespace
|
||||
void Conv2D::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode, const PadMode &pad_mode,
|
||||
const std::vector<int64_t> &pad, const std::vector<int64_t> &stride,
|
||||
const std::vector<int64_t> &dilation, int64_t group, const Format &format) {
|
||||
set_kernel_size(kernel_size);
|
||||
set_stride(stride);
|
||||
set_dilation(dilation);
|
||||
set_pad(pad);
|
||||
set_pad_mode(pad_mode);
|
||||
set_mode(mode);
|
||||
set_out_channel(out_channel);
|
||||
set_group(group);
|
||||
set_format(format);
|
||||
}
|
||||
|
||||
void Conv2D::set_out_channel(int64_t out_channel) {
|
||||
AddAttr(kOutChannel,
|
||||
MakeValue(CheckAndConvertUtils::CheckInteger(kOutChannel, out_channel, kGreaterThan, 0, name())));
|
||||
}
|
||||
|
||||
void Conv2D::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
||||
AddAttr(kKernelSize, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, name())));
|
||||
}
|
||||
|
||||
void Conv2D::set_stride(const std::vector<int64_t> &stride) {
|
||||
AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name(), true, true)));
|
||||
}
|
||||
|
||||
void Conv2D::set_dilation(const std::vector<int64_t> &dilation) {
|
||||
AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name(), true, true)));
|
||||
}
|
||||
|
||||
void Conv2D::set_pad_mode(const PadMode &pad_mode) {
|
||||
std::vector<int64_t> pad = get_pad();
|
||||
if (pad_mode == PAD) {
|
||||
for (auto item : pad) {
|
||||
CheckAndConvertUtils::Check(kPadItem, item, kGreaterEqual, "zeros_list", 0, name());
|
||||
}
|
||||
} else {
|
||||
CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, name());
|
||||
}
|
||||
int64_t swi = pad_mode;
|
||||
AddAttr(kPadMode, MakeValue(swi));
|
||||
}
|
||||
|
||||
void Conv2D::set_pad(const std::vector<int64_t> &pad) {
|
||||
CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name());
|
||||
AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name(), true, true)));
|
||||
}
|
||||
|
||||
void Conv2D::set_mode(int64_t mode) {
|
||||
AddAttr(kMode, MakeValue(CheckAndConvertUtils::CheckInteger(kMode, mode, kEqual, 1, name())));
|
||||
}
|
||||
|
||||
void Conv2D::set_group(int64_t group) {
|
||||
AddAttr(kGroup, MakeValue(CheckAndConvertUtils::CheckInteger(kGroup, group, kGreaterThan, 0, name())));
|
||||
}
|
||||
|
||||
void Conv2D::set_format(const Format &format) {
|
||||
int64_t f = format;
|
||||
AddAttr(kFormat, MakeValue(f));
|
||||
}
|
||||
|
||||
int64_t Conv2D::get_out_channel() const {
|
||||
auto value_ptr = GetAttr(kOutChannel);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
std::vector<int64_t> Conv2D::get_kernel_size() const {
|
||||
auto value_ptr = GetAttr(kKernelSize);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
std::vector<int64_t> Conv2D::get_stride() const {
|
||||
auto value_ptr = GetAttr(kStride);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
std::vector<int64_t> Conv2D::get_dilation() const {
|
||||
auto value_ptr = GetAttr(kDilation);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
PadMode Conv2D::get_pad_mode() const {
|
||||
auto value_ptr = GetAttr(kPadMode);
|
||||
return PadMode(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
std::vector<int64_t> Conv2D::get_pad() const {
|
||||
auto value_ptr = GetAttr(kPad);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
int64_t Conv2D::get_mode() const {
|
||||
auto value_ptr = GetAttr(kMode);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
int64_t Conv2D::get_group() const {
|
||||
auto value_ptr = GetAttr(kGroup);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
Format Conv2D::get_format() const {
|
||||
auto value_ptr = GetAttr(kFormat);
|
||||
return Format(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(Conv2dInferType(primitive, input_args),
|
||||
Conv2dInferShape(primitive, input_args)->shape());
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2D, prim::kPrimConv2D, Conv2dInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameConv2D, Conv2D);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -14,49 +14,52 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_CONV2D_H_
|
||||
#define MINDSPORE_CORE_C_OPS_CONV2D_H_
|
||||
#ifndef MINDSPORE_CORE_OPS_CONV2D_H_
|
||||
#define MINDSPORE_CORE_OPS_CONV2D_H_
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "c_ops/op_utils.h"
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameConv2D = "Conv2D";
|
||||
class Conv2D : public PrimitiveC {
|
||||
public:
|
||||
Conv2D();
|
||||
Conv2D() : PrimitiveC(kNameConv2D) { InitIOName({"x", "w"}, {"output"}); }
|
||||
explicit Conv2D(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x", "w"}, {"output"}); }
|
||||
~Conv2D() = default;
|
||||
MS_DECLARE_PARENT(Conv2D, PrimitiveC);
|
||||
void Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode = 1,
|
||||
const std::string &pad_mode = "valid", const std::vector<int64_t> &pad = {0, 0, 0, 0},
|
||||
const PadMode &pad_mode = VALID, const std::vector<int64_t> &pad = {0, 0, 0, 0},
|
||||
const std::vector<int64_t> &stride = {1, 1, 1, 1}, const std::vector<int64_t> &dilation = {1, 1, 1, 1},
|
||||
int64_t group = 1);
|
||||
std::vector<int64_t> get_kernel_size() const;
|
||||
std::vector<int64_t> get_stride() const;
|
||||
std::vector<int64_t> get_dilation() const;
|
||||
std::string get_pad_mode() const;
|
||||
std::vector<int64_t> get_pad() const;
|
||||
std::vector<int64_t> get_pad_list() const;
|
||||
int64_t get_mode() const;
|
||||
int64_t get_group() const;
|
||||
int64_t get_output_channel() const;
|
||||
int64_t group = 1, const Format &format = NCHW);
|
||||
void set_kernel_size(const std::vector<int64_t> &kernel_size);
|
||||
void set_stride(const std::vector<int64_t> &stride);
|
||||
void set_dilation(const std::vector<int64_t> &dilation);
|
||||
void set_pad_mode(const std::string &pad_mode);
|
||||
void set_pad_mode(const PadMode &pad_mode);
|
||||
void set_pad(const std::vector<int64_t> &pad);
|
||||
void set_mode(int64_t mode);
|
||||
void set_group(int64_t group);
|
||||
void set_out_channel(int64_t output_channel);
|
||||
void set_pad_list(const std::vector<int64_t> &pad_list);
|
||||
void set_out_channel(int64_t out_channel);
|
||||
void set_format(const Format &format);
|
||||
std::vector<int64_t> get_kernel_size() const;
|
||||
std::vector<int64_t> get_stride() const;
|
||||
std::vector<int64_t> get_dilation() const;
|
||||
PadMode get_pad_mode() const;
|
||||
std::vector<int64_t> get_pad() const;
|
||||
int64_t get_mode() const;
|
||||
int64_t get_group() const;
|
||||
int64_t get_out_channel() const;
|
||||
Format get_format() const;
|
||||
};
|
||||
AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimConv2dPtr = std::shared_ptr<Conv2D>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_C_OPS_CONV2D_H_
|
||||
#endif // MINDSPORE_CORE_OPS_CONV2D_H_
|
|
@ -0,0 +1,199 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
|
||||
#include "ops/conv2d_transpose.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr Conv2dTransposeInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto conv2d_transpose_prim = primitive->cast<PrimConv2dTransposePtr>();
|
||||
MS_EXCEPTION_IF_NULL(conv2d_transpose_prim);
|
||||
auto prim_name = conv2d_transpose_prim->name();
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[3]->BuildShape(), prim_name);
|
||||
return std::make_shared<abstract::Shape>(input_shape);
|
||||
}
|
||||
|
||||
TypePtr Conv2dTransposeInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
CheckAndConvertUtils::CheckInteger("conv2d_transpose_infer", input_args.size(), kEqual, 3, prim->name());
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32};
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("doutput_dtye", input_args[0]->BuildType());
|
||||
types.emplace("w_dtype", input_args[1]->BuildType());
|
||||
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
return TypeIdToType(infer_type);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void Conv2dTranspose::Init(int64_t in_channel, int64_t out_channel, const std::vector<int64_t> &kernel_size,
|
||||
int64_t mode, const PadMode &pad_mode, const std::vector<int64_t> &pad,
|
||||
const std::vector<int64_t> &stride, const std::vector<int64_t> &dilation, int64_t group,
|
||||
const Format &format, const std::vector<int64_t> &pad_list) {
|
||||
set_in_channel(in_channel);
|
||||
set_out_channel(out_channel);
|
||||
set_kernel_size(kernel_size);
|
||||
set_mode(mode);
|
||||
set_pad(pad);
|
||||
set_pad_mode(pad_mode);
|
||||
set_stride(stride);
|
||||
set_dilation(dilation);
|
||||
set_group(group);
|
||||
set_format(format);
|
||||
set_pad_list(pad_list);
|
||||
}
|
||||
|
||||
void Conv2dTranspose::set_in_channel(int64_t in_channel) {
|
||||
AddAttr(kOutChannel, MakeValue(CheckAndConvertUtils::CheckInteger(kInChannel, in_channel, kGreaterThan, 0, name())));
|
||||
}
|
||||
|
||||
void Conv2dTranspose::set_out_channel(int64_t out_channel) {
|
||||
AddAttr(kOutChannel,
|
||||
MakeValue(CheckAndConvertUtils::CheckInteger(kOutChannel, out_channel, kGreaterThan, 0, name())));
|
||||
}
|
||||
|
||||
void Conv2dTranspose::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
||||
CheckAndConvertUtils::CheckInteger(kKernelSize, kernel_size.size(), kEqual, 2, name());
|
||||
for (int64_t item : kernel_size) {
|
||||
CheckAndConvertUtils::CheckInteger(kKernelSize, item, kGreaterEqual, 1, name());
|
||||
}
|
||||
AddAttr(kKernelSize, MakeValue(kernel_size));
|
||||
}
|
||||
|
||||
void Conv2dTranspose::set_stride(const std::vector<int64_t> &stride) {
|
||||
CheckAndConvertUtils::CheckInteger(kStride, stride.size(), kEqual, 2, name());
|
||||
for (int64_t item : stride) {
|
||||
CheckAndConvertUtils::CheckInteger(kStride, item, kGreaterEqual, 1, name());
|
||||
}
|
||||
AddAttr(kStride, MakeValue(stride));
|
||||
}
|
||||
|
||||
void Conv2dTranspose::set_dilation(const std::vector<int64_t> &dilation) {
|
||||
CheckAndConvertUtils::CheckInteger(kDilation, dilation.size(), kGreaterEqual, 2, name());
|
||||
AddAttr(kDilation, MakeValue(dilation));
|
||||
}
|
||||
|
||||
void Conv2dTranspose::set_pad_mode(const PadMode &pad_mode) {
|
||||
std::vector<int64_t> pad = get_pad();
|
||||
if (pad_mode == PAD) {
|
||||
for (auto item : pad) {
|
||||
CheckAndConvertUtils::Check(kPadItem, item, kGreaterEqual, "zeros_list", 0, name());
|
||||
}
|
||||
} else {
|
||||
CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, name());
|
||||
}
|
||||
int64_t swi = pad_mode;
|
||||
AddAttr(kPadMode, MakeValue(swi));
|
||||
}
|
||||
|
||||
void Conv2dTranspose::set_pad(const std::vector<int64_t> &pad) {
|
||||
CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name());
|
||||
AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name(), true, true)));
|
||||
}
|
||||
|
||||
void Conv2dTranspose::set_mode(int64_t mode) {
|
||||
AddAttr(kMode, MakeValue(CheckAndConvertUtils::CheckInteger(kMode, mode, kEqual, 1, name())));
|
||||
}
|
||||
|
||||
void Conv2dTranspose::set_group(int64_t group) {
|
||||
AddAttr(kGroup, MakeValue(CheckAndConvertUtils::CheckInteger(kGroup, group, kGreaterThan, 0, name())));
|
||||
}
|
||||
|
||||
void Conv2dTranspose::set_format(const Format &format) {
|
||||
int64_t f = format;
|
||||
AddAttr(kFormat, MakeValue(f));
|
||||
}
|
||||
|
||||
void Conv2dTranspose::set_pad_list(const std::vector<int64_t> &pad_list) {
|
||||
CheckAndConvertUtils::CheckInteger(kPadList, pad_list.size(), kEqual, 4, name());
|
||||
this->AddAttr(kPadList, MakeValue(pad_list));
|
||||
}
|
||||
|
||||
int64_t Conv2dTranspose::get_in_channel() const {
|
||||
auto value_ptr = GetAttr(kInChannel);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
int64_t Conv2dTranspose::get_out_channel() const {
|
||||
auto value_ptr = GetAttr(kOutChannel);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
std::vector<int64_t> Conv2dTranspose::get_kernel_size() const {
|
||||
auto value_ptr = GetAttr(kKernelSize);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
std::vector<int64_t> Conv2dTranspose::get_stride() const {
|
||||
auto value_ptr = GetAttr(kStride);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
std::vector<int64_t> Conv2dTranspose::get_dilation() const {
|
||||
auto value_ptr = GetAttr(kDilation);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
PadMode Conv2dTranspose::get_pad_mode() const {
|
||||
auto value_ptr = GetAttr(kPadMode);
|
||||
return PadMode(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
std::vector<int64_t> Conv2dTranspose::get_pad() const {
|
||||
auto value_ptr = GetAttr(kPad);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
int64_t Conv2dTranspose::get_mode() const {
|
||||
auto value_ptr = GetAttr(kMode);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
int64_t Conv2dTranspose::get_group() const {
|
||||
auto value_ptr = GetAttr(kGroup);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
Format Conv2dTranspose::get_format() const {
|
||||
auto value_ptr = GetAttr(kFormat);
|
||||
return Format(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
std::vector<int64_t> Conv2dTranspose::get_pad_list() const {
|
||||
auto value_ptr = GetAttr(kPadList);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
AbstractBasePtr Conv2dTransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(Conv2dTransposeInferType(primitive, input_args),
|
||||
Conv2dTransposeInferShape(primitive, input_args)->shape());
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2dTranspose, prim::kPrimConv2DTranspose, Conv2dTransposeInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameConv2dTranspose, Conv2dTranspose);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,74 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_CONV2D_TRANSPOSE_H_
|
||||
#define MINDSPORE_CORE_OPS_CONV2D_TRANSPOSE_H_
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/op_utils.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameConv2dTranspose = "Conv2dTranspose";
|
||||
class Conv2dTranspose : public PrimitiveC {
|
||||
public:
|
||||
Conv2dTranspose() : PrimitiveC(kNameConv2dTranspose) {
|
||||
InitIOName({"out_backprop", "filter", "input_sizes"}, {"output"});
|
||||
}
|
||||
explicit Conv2dTranspose(const std::string k_name) : PrimitiveC(k_name) {
|
||||
InitIOName({"out_backprop", "filter", "input_sizes"}, {"output"});
|
||||
}
|
||||
~Conv2dTranspose() = default;
|
||||
MS_DECLARE_PARENT(Conv2dTranspose, PrimitiveC);
|
||||
void Init(int64_t in_channel, int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode = 1,
|
||||
const PadMode &pad_mode = VALID, const std::vector<int64_t> &pad = {0, 0, 0, 0},
|
||||
const std::vector<int64_t> &stride = {1, 1}, const std::vector<int64_t> &dilation = {1, 1},
|
||||
int64_t group = 1, const Format &format = NCHW, const std::vector<int64_t> &pad_list = {0, 0, 0, 0});
|
||||
void set_in_channel(int64_t in_channel);
|
||||
void set_out_channel(int64_t out_channel);
|
||||
virtual void set_kernel_size(const std::vector<int64_t> &kernel_size);
|
||||
void set_stride(const std::vector<int64_t> &stride);
|
||||
virtual void set_dilation(const std::vector<int64_t> &dilation);
|
||||
void set_pad_mode(const PadMode &pad_mode);
|
||||
void set_pad(const std::vector<int64_t> &pad);
|
||||
void set_mode(int64_t mode);
|
||||
void set_group(int64_t group);
|
||||
void set_format(const Format &format);
|
||||
void set_pad_list(const std::vector<int64_t> &pad_list);
|
||||
|
||||
int64_t get_in_channel() const;
|
||||
int64_t get_out_channel() const;
|
||||
std::vector<int64_t> get_kernel_size() const;
|
||||
std::vector<int64_t> get_stride() const;
|
||||
std::vector<int64_t> get_dilation() const;
|
||||
PadMode get_pad_mode() const;
|
||||
std::vector<int64_t> get_pad() const;
|
||||
int64_t get_mode() const;
|
||||
int64_t get_group() const;
|
||||
Format get_format() const;
|
||||
std::vector<int64_t> get_pad_list() const;
|
||||
};
|
||||
AbstractBasePtr Conv2dTransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimConv2dTransposePtr = std::shared_ptr<Conv2dTranspose>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_OPS_CONV2D_TRANSPOSE_H_
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue