!48662 Fix ACL Set IO Size and infer bug.
Merge pull request !48662 from linqingke/acl
This commit is contained in:
commit
7099307702
|
@ -104,10 +104,13 @@ if(MODE_ASCEND_ALL)
|
|||
find_library(OPTILING optiling ${ASCEND_CANN_OPP_AARCH64_PATH} ${ASCEND_TOOLKIT_OPP_AARCH64_PATH})
|
||||
endif()
|
||||
find_library(ACL_OP_COMPILER acl_op_compiler ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
find_library(CANN_KB cann_kb ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
find_library(COMPRESS compress ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
find_library(OPSKERNEL opskernel ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
|
||||
target_link_libraries(mindspore_ascend PRIVATE ${RUNTIME_LIB} ${TSDCLIENT} ${DATATRANSFER} ${ERROR_MANAGER}
|
||||
-Wl,--no-as-needed ${OPTILING} ${PLATFORM} ${ACL} ${ACL_TDT_CHANNEL} ${OPT_FEATURE} ${PROFILING}
|
||||
${ACL_OP_COMPILER})
|
||||
${ACL_OP_COMPILER} ${CANN_KB} ${COMPRESS} ${OPSKERNEL})
|
||||
target_link_libraries(mindspore_ascend PRIVATE ${adump_server})
|
||||
endif()
|
||||
|
||||
|
|
|
@ -67,7 +67,7 @@ INPUT_MAP(ApplyAdagradD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, IN
|
|||
ATTR_MAP(ApplyAdagradD) = {{"update_slots", ATTR_DESC(update_slots, AnyTraits<bool>())},
|
||||
{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(ApplyAdagradD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}};
|
||||
REG_ADPT_DESC(ApplyAdagradD, kNameApplyAdagrad, ADPT_DESC(ApplyAdagradD))
|
||||
REG_ADPT_DESC(ApplyAdagradD, kApplyAdagradDOpName, ADPT_DESC(ApplyAdagradD))
|
||||
|
||||
// ApplyAdagradV2D
|
||||
INPUT_MAP(ApplyAdagradV2D) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)}, {4, INPUT_DESC(grad)}};
|
||||
|
@ -83,7 +83,7 @@ INPUT_MAP(ApplyAddSignD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)},
|
|||
{7, INPUT_DESC(grad)}};
|
||||
ATTR_MAP(ApplyAddSignD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(ApplyAddSignD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(m)}};
|
||||
REG_ADPT_DESC(ApplyAddSignD, kNameApplyAddSign, ADPT_DESC(ApplyAddSignD))
|
||||
REG_ADPT_DESC(ApplyAddSignD, kApplyAddSignDOpName, ADPT_DESC(ApplyAddSignD))
|
||||
|
||||
// SparseApplyAdagradV2D
|
||||
INPUT_MAP(SparseApplyAdagradV2D) = {
|
||||
|
@ -93,7 +93,7 @@ ATTR_MAP(SparseApplyAdagradV2D) = {{"lr", ATTR_DESC(lr, AnyTraits<float>())},
|
|||
{"update_slots", ATTR_DESC(update_slots, AnyTraits<bool>())},
|
||||
{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(SparseApplyAdagradV2D) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}};
|
||||
REG_ADPT_DESC(SparseApplyAdagradV2D, kNameSparseApplyAdagradV2, ADPT_DESC(SparseApplyAdagradV2D))
|
||||
REG_ADPT_DESC(SparseApplyAdagradV2D, kSparseApplyAdagradV2DOpName, ADPT_DESC(SparseApplyAdagradV2D))
|
||||
|
||||
// DataFormatDimMap
|
||||
INPUT_MAP(DataFormatDimMap) = {{1, INPUT_DESC(x)}};
|
||||
|
@ -116,7 +116,7 @@ INPUT_MAP(ApplyAdaMaxD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)},
|
|||
{7, INPUT_DESC(beta2)}, {8, INPUT_DESC(epsilon)}, {9, INPUT_DESC(grad)}};
|
||||
ATTR_MAP(ApplyAdaMaxD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(ApplyAdaMaxD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(m)}, {2, OUTPUT_DESC(v)}};
|
||||
REG_ADPT_DESC(ApplyAdaMaxD, kNameApplyAdaMax, ADPT_DESC(ApplyAdaMaxD))
|
||||
REG_ADPT_DESC(ApplyAdaMaxD, kApplyAdaMaxDOpName, ADPT_DESC(ApplyAdaMaxD))
|
||||
|
||||
// ApplyGradientDescent
|
||||
INPUT_MAP(ApplyGradientDescent) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(alpha)}, {3, INPUT_DESC(delta)}};
|
||||
|
@ -184,7 +184,7 @@ ATTR_MAP(SparseApplyFtrlD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<b
|
|||
{"l1", ATTR_DESC(l1, AnyTraits<float>())},
|
||||
{"l2", ATTR_DESC(l2, AnyTraits<float>())},
|
||||
{"lr_power", ATTR_DESC(lr_power, AnyTraits<float>())}};
|
||||
OUTPUT_MAP(SparseApplyFtrlD) = {{0, OUTPUT_DESC(var)}};
|
||||
OUTPUT_MAP(SparseApplyFtrlD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}, {2, OUTPUT_DESC(linear)}};
|
||||
REG_ADPT_DESC(SparseApplyFtrlD, kNameSparseApplyFtrlD, ADPT_DESC(SparseApplyFtrlD))
|
||||
|
||||
// SparseApplyFtrl
|
||||
|
@ -253,7 +253,7 @@ INPUT_MAP(ApplyAdaMax) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)},
|
|||
{7, INPUT_DESC(beta2)}, {8, INPUT_DESC(epsilon)}, {9, INPUT_DESC(grad)}};
|
||||
ATTR_MAP(ApplyAdaMax) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(ApplyAdaMax) = {{0, OUTPUT_DESC(var)}};
|
||||
REG_ADPT_DESC(ApplyAdaMax, kApplyAdaMaxDOpName, ADPT_DESC(ApplyAdaMax))
|
||||
REG_ADPT_DESC(ApplyAdaMax, kNameApplyAdaMax, ADPT_DESC(ApplyAdaMax))
|
||||
|
||||
// SparseApplyAdagrad
|
||||
INPUT_MAP(SparseApplyAdagrad) = {
|
||||
|
@ -270,7 +270,7 @@ ATTR_INPUT_MAP(SparseApplyAdagradV2) = {{"lr", "lr"}, {"epsilon", "epsilon"}};
|
|||
ATTR_MAP(SparseApplyAdagradV2) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())},
|
||||
{"update_slots", ATTR_DESC(update_slots, AnyTraits<float>())}};
|
||||
OUTPUT_MAP(SparseApplyAdagradV2) = {{0, OUTPUT_DESC(var)}};
|
||||
REG_ADPT_DESC(SparseApplyAdagradV2, kSparseApplyAdagradV2DOpName, ADPT_DESC(SparseApplyAdagradV2))
|
||||
REG_ADPT_DESC(SparseApplyAdagradV2, kNameSparseApplyAdagradV2, ADPT_DESC(SparseApplyAdagradV2))
|
||||
|
||||
// ApplyKerasMomentum
|
||||
INPUT_MAP(ApplyKerasMomentum) = {
|
||||
|
@ -278,7 +278,15 @@ INPUT_MAP(ApplyKerasMomentum) = {
|
|||
ATTR_MAP(ApplyKerasMomentum) = {{"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())},
|
||||
{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(ApplyKerasMomentum) = {{0, OUTPUT_DESC(var)}};
|
||||
REG_ADPT_DESC(ApplyKerasMomentum, kApplyKerasMomentumDOpName, ADPT_DESC(ApplyKerasMomentum))
|
||||
REG_ADPT_DESC(ApplyKerasMomentum, kApplyKerasMomentumOpName, ADPT_DESC(ApplyKerasMomentum))
|
||||
|
||||
// ApplyKerasMomentumD
|
||||
INPUT_MAP(ApplyKerasMomentumD) = {
|
||||
{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)}, {4, INPUT_DESC(grad)}, {5, INPUT_DESC(momentum)}};
|
||||
ATTR_MAP(ApplyKerasMomentumD) = {{"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())},
|
||||
{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(ApplyKerasMomentumD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}};
|
||||
REG_ADPT_DESC(ApplyKerasMomentumD, kApplyKerasMomentumDOpName, ADPT_DESC(ApplyKerasMomentumD))
|
||||
|
||||
// ApplyAdamWithAmsgrad
|
||||
INPUT_MAP(ApplyAdamWithAmsgrad) = {
|
||||
|
@ -304,14 +312,14 @@ INPUT_MAP(ApplyAddSign) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)},
|
|||
{7, INPUT_DESC(grad)}};
|
||||
ATTR_MAP(ApplyAddSign) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(ApplyAddSign) = {{0, OUTPUT_DESC(var)}};
|
||||
REG_ADPT_DESC(ApplyAddSign, kApplyAddSignDOpName, ADPT_DESC(ApplyAddSign))
|
||||
REG_ADPT_DESC(ApplyAddSign, kNameApplyAddSign, ADPT_DESC(ApplyAddSign))
|
||||
|
||||
// ApplyAdagrad
|
||||
INPUT_MAP(ApplyAdagrad) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)}, {4, INPUT_DESC(grad)}};
|
||||
ATTR_MAP(ApplyAdagrad) = {{"update_slots", ATTR_DESC(update_slots, AnyTraits<bool>())},
|
||||
{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(ApplyAdagrad) = {{0, OUTPUT_DESC(var)}};
|
||||
REG_ADPT_DESC(ApplyAdagrad, kApplyAdagradDOpName, ADPT_DESC(ApplyAdagrad))
|
||||
REG_ADPT_DESC(ApplyAdagrad, kNameApplyAdagrad, ADPT_DESC(ApplyAdagrad))
|
||||
|
||||
// ApplyAdagradV2
|
||||
INPUT_MAP(ApplyAdagradV2) = {
|
||||
|
|
|
@ -110,6 +110,9 @@ DECLARE_OP_USE_OUTPUT(SparseApplyAdagrad)
|
|||
DECLARE_OP_ADAPTER(SparseApplyAdagradV2)
|
||||
DECLARE_OP_USE_OUTPUT(SparseApplyAdagradV2)
|
||||
|
||||
DECLARE_OP_ADAPTER(ApplyKerasMomentumD)
|
||||
DECLARE_OP_USE_OUTPUT(ApplyKerasMomentumD)
|
||||
|
||||
DECLARE_OP_ADAPTER(ApplyKerasMomentum)
|
||||
DECLARE_OP_USE_OUTPUT(ApplyKerasMomentum)
|
||||
|
||||
|
|
|
@ -68,28 +68,13 @@ OUTPUT_MAP(ReduceSum) = {{0, OUTPUT_DESC(y)}};
|
|||
REG_ADPT_DESC(ReduceSum, prim::kPrimReduceSum->name(), ADPT_DESC(ReduceSum))
|
||||
REG_ADPT_DESC(ReduceSumD, prim::kPrimReduceSumD->name(), ADPT_DESC(ReduceSum))
|
||||
|
||||
// ReduceProdD
|
||||
INPUT_MAP(ReduceProdD) = {{1, INPUT_DESC(x)}};
|
||||
INPUT_ATTR_MAP(ReduceProdD) = {
|
||||
{2, ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
|
||||
ATTR_MAP(ReduceProdD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(ReduceProdD) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(ReduceProd, kNameReduceProd, ADPT_DESC(ReduceProdD))
|
||||
|
||||
// ReduceAllD
|
||||
INPUT_MAP(ReduceAllD) = {{1, INPUT_DESC(x)}};
|
||||
INPUT_ATTR_MAP(ReduceAllD) = {
|
||||
{2, ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
|
||||
ATTR_MAP(ReduceAllD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(ReduceAllD) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(ReduceAllD, prim::kPrimReduceAll->name(), ADPT_DESC(ReduceAllD))
|
||||
|
||||
// ReduceAll
|
||||
INPUT_MAP(ReduceAll) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axes)}};
|
||||
ATTR_INPUT_MAP(ReduceAll) = {{"axis", "axes"}};
|
||||
ATTR_MAP(ReduceAll) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(ReduceAll) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(ReduceAll, prim::kPrimReduceAllD->name(), ADPT_DESC(ReduceAll))
|
||||
REG_ADPT_DESC(ReduceAll, prim::kPrimReduceAll->name(), ADPT_DESC(ReduceAll))
|
||||
REG_ADPT_DESC(ReduceAllD, prim::kPrimReduceAllD->name(), ADPT_DESC(ReduceAll))
|
||||
|
||||
// ReduceMean
|
||||
INPUT_MAP(ReduceMean) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axes)}};
|
||||
|
@ -99,34 +84,20 @@ OUTPUT_MAP(ReduceMean) = {{0, OUTPUT_DESC(y)}};
|
|||
REG_ADPT_DESC(ReduceMean, prim::kPrimReduceMean->name(), ADPT_DESC(ReduceMean))
|
||||
REG_ADPT_DESC(ReduceMeanD, prim::kPrimReduceMeanD->name(), ADPT_DESC(ReduceMean))
|
||||
|
||||
// ReduceMinD
|
||||
INPUT_MAP(ReduceMinD) = {{1, INPUT_DESC(x)}};
|
||||
INPUT_ATTR_MAP(ReduceMinD) = {
|
||||
{2, ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
|
||||
ATTR_MAP(ReduceMinD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(ReduceMinD) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(ReduceMinD, prim::kPrimReduceMin->name(), ADPT_DESC(ReduceMinD))
|
||||
|
||||
// ReduceMin
|
||||
INPUT_MAP(ReduceMin) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axes)}};
|
||||
ATTR_INPUT_MAP(ReduceMin) = {{"axis", "axes"}};
|
||||
ATTR_MAP(ReduceMin) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(ReduceMin) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(ReduceMin, prim::kPrimReduceMinD->name(), ADPT_DESC(ReduceMin))
|
||||
|
||||
// ReduceMaxD
|
||||
INPUT_MAP(ReduceMaxD) = {{1, INPUT_DESC(x)}};
|
||||
INPUT_ATTR_MAP(ReduceMaxD) = {
|
||||
{2, ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
|
||||
ATTR_MAP(ReduceMaxD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(ReduceMaxD) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(ReduceMax, prim::kPrimReduceMax->name(), ADPT_DESC(ReduceMaxD))
|
||||
REG_ADPT_DESC(ReduceMin, prim::kPrimReduceMin->name(), ADPT_DESC(ReduceMin))
|
||||
REG_ADPT_DESC(ReduceMinD, prim::kPrimReduceMinD->name(), ADPT_DESC(ReduceMin))
|
||||
|
||||
// ReduceMax
|
||||
INPUT_MAP(ReduceMax) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axes)}};
|
||||
ATTR_INPUT_MAP(ReduceMax) = {{"axis", "axes"}};
|
||||
ATTR_MAP(ReduceMax) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(ReduceMax) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(ReduceMax, prim::kPrimReduceMax->name(), ADPT_DESC(ReduceMax))
|
||||
REG_ADPT_DESC(ReduceMaxD, prim::kPrimReduceMaxD->name(), ADPT_DESC(ReduceMax))
|
||||
|
||||
// ReduceStd
|
||||
|
@ -142,6 +113,7 @@ INPUT_MAP(ReduceProd) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axes)}};
|
|||
ATTR_INPUT_MAP(ReduceProd) = {{"axis", "axes"}};
|
||||
ATTR_MAP(ReduceProd) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(ReduceProd) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(ReduceProd, prim::kPrimReduceProd->name(), ADPT_DESC(ReduceProd))
|
||||
REG_ADPT_DESC(DynamicReduceProd, kNameDynamicReduceProd, ADPT_DESC(ReduceProd))
|
||||
REG_ADPT_DESC(ReduceProdD, prim::kPrimReduceProdD->name(), ADPT_DESC(ReduceProd))
|
||||
} // namespace mindspore::transform
|
||||
|
|
|
@ -25,24 +25,12 @@ namespace mindspore::transform {
|
|||
DECLARE_OP_ADAPTER(ReduceMean)
|
||||
DECLARE_OP_USE_OUTPUT(ReduceMean)
|
||||
|
||||
DECLARE_OP_ADAPTER(ReduceMinD)
|
||||
DECLARE_OP_USE_INPUT_ATTR(ReduceMinD)
|
||||
DECLARE_OP_USE_OUTPUT(ReduceMinD)
|
||||
|
||||
DECLARE_OP_ADAPTER(ReduceMin)
|
||||
DECLARE_OP_USE_OUTPUT(ReduceMin)
|
||||
|
||||
DECLARE_OP_ADAPTER(ReduceMaxD)
|
||||
DECLARE_OP_USE_INPUT_ATTR(ReduceMaxD)
|
||||
DECLARE_OP_USE_OUTPUT(ReduceMaxD)
|
||||
|
||||
DECLARE_OP_ADAPTER(ReduceMax)
|
||||
DECLARE_OP_USE_OUTPUT(ReduceMax)
|
||||
|
||||
DECLARE_OP_ADAPTER(ReduceAllD)
|
||||
DECLARE_OP_USE_INPUT_ATTR(ReduceAllD)
|
||||
DECLARE_OP_USE_OUTPUT(ReduceAllD)
|
||||
|
||||
DECLARE_OP_ADAPTER(ReduceAll)
|
||||
DECLARE_OP_USE_OUTPUT(ReduceAll)
|
||||
|
||||
|
@ -64,10 +52,6 @@ DECLARE_OP_USE_OUTPUT(ReduceSum)
|
|||
DECLARE_OP_ADAPTER(ReduceAny)
|
||||
DECLARE_OP_USE_OUTPUT(ReduceAny)
|
||||
|
||||
DECLARE_OP_ADAPTER(ReduceProdD)
|
||||
DECLARE_OP_USE_INPUT_ATTR(ReduceProdD)
|
||||
DECLARE_OP_USE_OUTPUT(ReduceProdD)
|
||||
|
||||
DECLARE_OP_ADAPTER(ReduceStd)
|
||||
DECLARE_OP_USE_OUTPUT(ReduceStd)
|
||||
|
||||
|
|
|
@ -205,13 +205,6 @@ ATTR_MAP(UnsortedSegmentProd) = EMPTY_ATTR_MAP;
|
|||
OUTPUT_MAP(UnsortedSegmentProd) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(UnsortedSegmentProd, kNameUnsortedSegmentProd, ADPT_DESC(UnsortedSegmentProd))
|
||||
|
||||
// UnsortedSegmentMaxD
|
||||
INPUT_MAP(UnsortedSegmentMaxD) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)}};
|
||||
INPUT_ATTR_MAP(UnsortedSegmentMaxD) = {{3, ATTR_DESC(num_segments, AnyTraits<int64_t>())}};
|
||||
ATTR_MAP(UnsortedSegmentMaxD) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(UnsortedSegmentMaxD) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(UnsortedSegmentMaxD, kNameUnsortedSegmentMax, ADPT_DESC(UnsortedSegmentMaxD))
|
||||
|
||||
// UnsortedSegmentMin
|
||||
INPUT_MAP(UnsortedSegmentMin) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)}, {3, INPUT_DESC(num_segments)}};
|
||||
ATTR_MAP(UnsortedSegmentMin) = EMPTY_ATTR_MAP;
|
||||
|
@ -296,5 +289,6 @@ INPUT_MAP(UnsortedSegmentMax) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)
|
|||
ATTR_INPUT_MAP(UnsortedSegmentMax) = {{"num_segments", "num_segments"}};
|
||||
ATTR_MAP(UnsortedSegmentMax) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(UnsortedSegmentMax) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(UnsortedSegmentMax, kUnsortedSegmentMaxDOpName, ADPT_DESC(UnsortedSegmentMax))
|
||||
REG_ADPT_DESC(UnsortedSegmentMax, kUnsortedSegmentMaxOpName, ADPT_DESC(UnsortedSegmentMax))
|
||||
REG_ADPT_DESC(UnsortedSegmentMaxD, kUnsortedSegmentMaxDOpName, ADPT_DESC(UnsortedSegmentMax))
|
||||
} // namespace mindspore::transform
|
||||
|
|
|
@ -64,10 +64,6 @@ DECLARE_OP_USE_OUTPUT(UnsortedSegmentSum)
|
|||
DECLARE_OP_ADAPTER(UnsortedSegmentProd)
|
||||
DECLARE_OP_USE_OUTPUT(UnsortedSegmentProd)
|
||||
|
||||
DECLARE_OP_ADAPTER(UnsortedSegmentMaxD)
|
||||
DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentMaxD)
|
||||
DECLARE_OP_USE_OUTPUT(UnsortedSegmentMaxD)
|
||||
|
||||
DECLARE_OP_ADAPTER(UnsortedSegmentMax)
|
||||
DECLARE_OP_USE_OUTPUT(UnsortedSegmentMax)
|
||||
|
||||
|
|
|
@ -124,7 +124,7 @@ REG_ADPT_DESC(SpaceToBatch, kSpaceToBatchDOpName, ADPT_DESC(SpaceToBatch))
|
|||
// ExtractVolumePatches
|
||||
INPUT_MAP(ExtractVolumePatches) = {{1, INPUT_DESC(x)}};
|
||||
ATTR_MAP(ExtractVolumePatches) = {
|
||||
{"ksizes", ATTR_DESC(ksizes, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
|
||||
{"kernel_size", ATTR_DESC(ksizes, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
|
||||
{"strides", ATTR_DESC(strides, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
|
||||
{"padding", ATTR_DESC(padding, AnyTraits<std::string>())}};
|
||||
OUTPUT_MAP(ExtractVolumePatches) = {{0, OUTPUT_DESC(y)}};
|
||||
|
|
|
@ -54,21 +54,10 @@ void DynamicRNNShapeCheck(const PrimitivePtr &primitive, const std::vector<Abstr
|
|||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx0]->BuildShape())[kShape];
|
||||
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx1]->BuildShape())[kShape];
|
||||
auto b_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx2]->BuildShape())[kShape];
|
||||
auto seq_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx3]->BuildShape())[kShape];
|
||||
auto h_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx4]->BuildShape())[kShape];
|
||||
auto c_shape_ptr = input_args[kDynRnnIdx5]->BuildShape();
|
||||
auto c_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx5]->BuildShape())[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger("x_shape", SizeToLong(x_shape.size()), kEqual, kDynamicRnnShapeX, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("w_shape", SizeToLong(w_shape.size()), kEqual, kDynamicRnnShapeW, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("b_shape", SizeToLong(b_shape.size()), kEqual, kDynamicRnnShapeB, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("h_shape", SizeToLong(h_shape.size()), kEqual, kDynamicRnnShapeH, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("c_shape", SizeToLong(c_shape.size()), kEqual, kDynamicRnnShapeC, op_name);
|
||||
int64_t batch_size = x_shape[kDynRnnIdx1];
|
||||
int64_t input_size = x_shape[kDynRnnIdx2];
|
||||
if (seq_shape.size() != 0) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << op_name << "', input 'seq' shape must be 0, but got " << seq_shape.size()
|
||||
<< ".";
|
||||
}
|
||||
int64_t hidden_size = w_shape[w_shape.size() - 1] / kDynRnnNum4;
|
||||
if (w_shape[w_shape.size() - 1] % kDynRnnNum4 != 0) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << op_name << "', w_shape[-1] should multiple of 4, now is "
|
||||
|
@ -83,11 +72,30 @@ void DynamicRNNShapeCheck(const PrimitivePtr &primitive, const std::vector<Abstr
|
|||
MS_EXCEPTION(ValueError) << "For '" << op_name << "', b_shape[0] should equal to w_shape[1], but gets "
|
||||
<< b_shape[0] << ".";
|
||||
}
|
||||
(void)CheckAndConvertUtils::CheckInteger("h_shape[0]", h_shape[kDynRnnIdx0], kEqual, (int64_t)1, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("h_shape[1]", h_shape[kDynRnnIdx1], kEqual, (int64_t)batch_size, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("h_shape[2]", h_shape[kDynRnnIdx2], kEqual, (int64_t)hidden_size, op_name);
|
||||
const std::map<std::string, BaseShapePtr> shapes = {{"c_shape", c_shape_ptr}};
|
||||
(void)CheckAndConvertUtils::CheckTensorShapeSame(shapes, h_shape, op_name);
|
||||
|
||||
if (input_args.size() > kDynRnnIdx3) {
|
||||
auto seq_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx3]->BuildShape())[kShape];
|
||||
if (seq_shape.size() != 0) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << op_name << "', input 'seq' shape must be 0, but got " << seq_shape.size()
|
||||
<< ".";
|
||||
}
|
||||
}
|
||||
if (input_args.size() > kDynRnnIdx4) {
|
||||
int64_t batch_size = x_shape[kDynRnnIdx1];
|
||||
auto h_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx4]->BuildShape())[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger("h_shape", SizeToLong(h_shape.size()), kEqual, kDynamicRnnShapeH, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("h_shape[0]", h_shape[kDynRnnIdx0], kEqual, (int64_t)1, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("h_shape[1]", h_shape[kDynRnnIdx1], kEqual, (int64_t)batch_size, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("h_shape[2]", h_shape[kDynRnnIdx2], kEqual, (int64_t)hidden_size, op_name);
|
||||
if (input_args.size() > kDynRnnIdx5) {
|
||||
auto c_shape_ptr = input_args[kDynRnnIdx5]->BuildShape();
|
||||
auto c_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx5]->BuildShape())[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger("c_shape", SizeToLong(c_shape.size()), kEqual, kDynamicRnnShapeC,
|
||||
op_name);
|
||||
const std::map<std::string, BaseShapePtr> shapes = {{"c_shape", c_shape_ptr}};
|
||||
(void)CheckAndConvertUtils::CheckTensorShapeSame(shapes, h_shape, op_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
abstract::TupleShapePtr DynamicRNNInferShape(const PrimitivePtr &primitive,
|
||||
|
|
Loading…
Reference in New Issue