!48662 Fix ACL Set IO Size and infer bug.

Merge pull request !48662 from linqingke/acl
i-robot 2023-02-16 02:31:49 +00:00 committed by Gitee
commit 7099307702
9 changed files with 58 additions and 90 deletions

@@ -104,10 +104,13 @@ if(MODE_ASCEND_ALL)
     find_library(OPTILING optiling ${ASCEND_CANN_OPP_AARCH64_PATH} ${ASCEND_TOOLKIT_OPP_AARCH64_PATH})
 endif()
 find_library(ACL_OP_COMPILER acl_op_compiler ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
+find_library(CANN_KB cann_kb ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
+find_library(COMPRESS compress ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
+find_library(OPSKERNEL opskernel ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
 target_link_libraries(mindspore_ascend PRIVATE ${RUNTIME_LIB} ${TSDCLIENT} ${DATATRANSFER} ${ERROR_MANAGER}
                       -Wl,--no-as-needed ${OPTILING} ${PLATFORM} ${ACL} ${ACL_TDT_CHANNEL} ${OPT_FEATURE} ${PROFILING}
-                      ${ACL_OP_COMPILER})
+                      ${ACL_OP_COMPILER} ${CANN_KB} ${COMPRESS} ${OPSKERNEL})
 target_link_libraries(mindspore_ascend PRIVATE ${adump_server})
 endif()

@@ -67,7 +67,7 @@ INPUT_MAP(ApplyAdagradD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, IN
 ATTR_MAP(ApplyAdagradD) = {{"update_slots", ATTR_DESC(update_slots, AnyTraits<bool>())},
                            {"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
 OUTPUT_MAP(ApplyAdagradD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}};
-REG_ADPT_DESC(ApplyAdagradD, kNameApplyAdagrad, ADPT_DESC(ApplyAdagradD))
+REG_ADPT_DESC(ApplyAdagradD, kApplyAdagradDOpName, ADPT_DESC(ApplyAdagradD))
 // ApplyAdagradV2D
 INPUT_MAP(ApplyAdagradV2D) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)}, {4, INPUT_DESC(grad)}};
@@ -83,7 +83,7 @@ INPUT_MAP(ApplyAddSignD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)},
                             {7, INPUT_DESC(grad)}};
 ATTR_MAP(ApplyAddSignD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
 OUTPUT_MAP(ApplyAddSignD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(m)}};
-REG_ADPT_DESC(ApplyAddSignD, kNameApplyAddSign, ADPT_DESC(ApplyAddSignD))
+REG_ADPT_DESC(ApplyAddSignD, kApplyAddSignDOpName, ADPT_DESC(ApplyAddSignD))
 // SparseApplyAdagradV2D
 INPUT_MAP(SparseApplyAdagradV2D) = {
@@ -93,7 +93,7 @@ ATTR_MAP(SparseApplyAdagradV2D) = {{"lr", ATTR_DESC(lr, AnyTraits<float>())},
                                    {"update_slots", ATTR_DESC(update_slots, AnyTraits<bool>())},
                                    {"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
 OUTPUT_MAP(SparseApplyAdagradV2D) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}};
-REG_ADPT_DESC(SparseApplyAdagradV2D, kNameSparseApplyAdagradV2, ADPT_DESC(SparseApplyAdagradV2D))
+REG_ADPT_DESC(SparseApplyAdagradV2D, kSparseApplyAdagradV2DOpName, ADPT_DESC(SparseApplyAdagradV2D))
 // DataFormatDimMap
 INPUT_MAP(DataFormatDimMap) = {{1, INPUT_DESC(x)}};
@@ -116,7 +116,7 @@ INPUT_MAP(ApplyAdaMaxD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)},
                            {7, INPUT_DESC(beta2)}, {8, INPUT_DESC(epsilon)}, {9, INPUT_DESC(grad)}};
 ATTR_MAP(ApplyAdaMaxD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
 OUTPUT_MAP(ApplyAdaMaxD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(m)}, {2, OUTPUT_DESC(v)}};
-REG_ADPT_DESC(ApplyAdaMaxD, kNameApplyAdaMax, ADPT_DESC(ApplyAdaMaxD))
+REG_ADPT_DESC(ApplyAdaMaxD, kApplyAdaMaxDOpName, ADPT_DESC(ApplyAdaMaxD))
 // ApplyGradientDescent
 INPUT_MAP(ApplyGradientDescent) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(alpha)}, {3, INPUT_DESC(delta)}};
@@ -184,7 +184,7 @@ ATTR_MAP(SparseApplyFtrlD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<b
                               {"l1", ATTR_DESC(l1, AnyTraits<float>())},
                               {"l2", ATTR_DESC(l2, AnyTraits<float>())},
                               {"lr_power", ATTR_DESC(lr_power, AnyTraits<float>())}};
-OUTPUT_MAP(SparseApplyFtrlD) = {{0, OUTPUT_DESC(var)}};
+OUTPUT_MAP(SparseApplyFtrlD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}, {2, OUTPUT_DESC(linear)}};
 REG_ADPT_DESC(SparseApplyFtrlD, kNameSparseApplyFtrlD, ADPT_DESC(SparseApplyFtrlD))
 // SparseApplyFtrl
@@ -253,7 +253,7 @@ INPUT_MAP(ApplyAdaMax) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)},
                           {7, INPUT_DESC(beta2)}, {8, INPUT_DESC(epsilon)}, {9, INPUT_DESC(grad)}};
 ATTR_MAP(ApplyAdaMax) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
 OUTPUT_MAP(ApplyAdaMax) = {{0, OUTPUT_DESC(var)}};
-REG_ADPT_DESC(ApplyAdaMax, kApplyAdaMaxDOpName, ADPT_DESC(ApplyAdaMax))
+REG_ADPT_DESC(ApplyAdaMax, kNameApplyAdaMax, ADPT_DESC(ApplyAdaMax))
 // SparseApplyAdagrad
 INPUT_MAP(SparseApplyAdagrad) = {
@@ -270,7 +270,7 @@ ATTR_INPUT_MAP(SparseApplyAdagradV2) = {{"lr", "lr"}, {"epsilon", "epsilon"}};
 ATTR_MAP(SparseApplyAdagradV2) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())},
                                   {"update_slots", ATTR_DESC(update_slots, AnyTraits<float>())}};
 OUTPUT_MAP(SparseApplyAdagradV2) = {{0, OUTPUT_DESC(var)}};
-REG_ADPT_DESC(SparseApplyAdagradV2, kSparseApplyAdagradV2DOpName, ADPT_DESC(SparseApplyAdagradV2))
+REG_ADPT_DESC(SparseApplyAdagradV2, kNameSparseApplyAdagradV2, ADPT_DESC(SparseApplyAdagradV2))
 // ApplyKerasMomentum
 INPUT_MAP(ApplyKerasMomentum) = {
@@ -278,7 +278,15 @@ INPUT_MAP(ApplyKerasMomentum) = {
 ATTR_MAP(ApplyKerasMomentum) = {{"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())},
                                 {"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
 OUTPUT_MAP(ApplyKerasMomentum) = {{0, OUTPUT_DESC(var)}};
-REG_ADPT_DESC(ApplyKerasMomentum, kApplyKerasMomentumDOpName, ADPT_DESC(ApplyKerasMomentum))
+REG_ADPT_DESC(ApplyKerasMomentum, kApplyKerasMomentumOpName, ADPT_DESC(ApplyKerasMomentum))
+// ApplyKerasMomentumD
+INPUT_MAP(ApplyKerasMomentumD) = {
+    {1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)}, {4, INPUT_DESC(grad)}, {5, INPUT_DESC(momentum)}};
+ATTR_MAP(ApplyKerasMomentumD) = {{"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())},
+                                 {"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
+OUTPUT_MAP(ApplyKerasMomentumD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}};
+REG_ADPT_DESC(ApplyKerasMomentumD, kApplyKerasMomentumDOpName, ADPT_DESC(ApplyKerasMomentumD))
 // ApplyAdamWithAmsgrad
 INPUT_MAP(ApplyAdamWithAmsgrad) = {
@@ -304,14 +312,14 @@ INPUT_MAP(ApplyAddSign) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)},
                            {7, INPUT_DESC(grad)}};
 ATTR_MAP(ApplyAddSign) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
 OUTPUT_MAP(ApplyAddSign) = {{0, OUTPUT_DESC(var)}};
-REG_ADPT_DESC(ApplyAddSign, kApplyAddSignDOpName, ADPT_DESC(ApplyAddSign))
+REG_ADPT_DESC(ApplyAddSign, kNameApplyAddSign, ADPT_DESC(ApplyAddSign))
 // ApplyAdagrad
 INPUT_MAP(ApplyAdagrad) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)}, {4, INPUT_DESC(grad)}};
 ATTR_MAP(ApplyAdagrad) = {{"update_slots", ATTR_DESC(update_slots, AnyTraits<bool>())},
                           {"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
 OUTPUT_MAP(ApplyAdagrad) = {{0, OUTPUT_DESC(var)}};
-REG_ADPT_DESC(ApplyAdagrad, kApplyAdagradDOpName, ADPT_DESC(ApplyAdagrad))
+REG_ADPT_DESC(ApplyAdagrad, kNameApplyAdagrad, ADPT_DESC(ApplyAdagrad))
 // ApplyAdagradV2
 INPUT_MAP(ApplyAdagradV2) = {
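
Most of the registration fixes in this file follow one pattern: the static -D variant of an op and its dynamic counterpart had swapped name constants, so two adapters competed for one registration key while the other key went unused. Below is a minimal sketch of that failure mode, assuming a map-based registry; RegisterAdapter and adapter_registry are hypothetical stand-ins, not MindSpore's actual REG_ADPT_DESC machinery.

// Minimal sketch (hypothetical; not MindSpore's REG_ADPT_DESC machinery) of
// why each op variant needs its own registration key.
#include <iostream>
#include <map>
#include <string>

std::map<std::string, std::string> adapter_registry;  // op name -> adapter id

bool RegisterAdapter(const std::string &op_name, const std::string &adapter) {
  // A second registration under the same key is silently dropped, so the op
  // that loses the race resolves to the wrong adapter at lookup time.
  return adapter_registry.emplace(op_name, adapter).second;
}

int main() {
  // Old layout: the static -D variant grabbed the dynamic op's name.
  RegisterAdapter("ApplyAdagrad", "ApplyAdagradD_adapter");
  RegisterAdapter("ApplyAdagrad", "ApplyAdagrad_adapter");  // dropped
  std::cout << adapter_registry["ApplyAdagrad"] << "\n";    // wrong adapter

  // Fixed layout, as in this commit: one key per operator variant.
  adapter_registry.clear();
  RegisterAdapter("ApplyAdagradD", "ApplyAdagradD_adapter");
  RegisterAdapter("ApplyAdagrad", "ApplyAdagrad_adapter");
  std::cout << adapter_registry["ApplyAdagrad"] << "\n";    // correct adapter
  return 0;
}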

@@ -110,6 +110,9 @@ DECLARE_OP_USE_OUTPUT(SparseApplyAdagrad)
 DECLARE_OP_ADAPTER(SparseApplyAdagradV2)
 DECLARE_OP_USE_OUTPUT(SparseApplyAdagradV2)
+DECLARE_OP_ADAPTER(ApplyKerasMomentumD)
+DECLARE_OP_USE_OUTPUT(ApplyKerasMomentumD)
 DECLARE_OP_ADAPTER(ApplyKerasMomentum)
 DECLARE_OP_USE_OUTPUT(ApplyKerasMomentum)

@@ -68,28 +68,13 @@ OUTPUT_MAP(ReduceSum) = {{0, OUTPUT_DESC(y)}};
 REG_ADPT_DESC(ReduceSum, prim::kPrimReduceSum->name(), ADPT_DESC(ReduceSum))
 REG_ADPT_DESC(ReduceSumD, prim::kPrimReduceSumD->name(), ADPT_DESC(ReduceSum))
-// ReduceProdD
-INPUT_MAP(ReduceProdD) = {{1, INPUT_DESC(x)}};
-INPUT_ATTR_MAP(ReduceProdD) = {
-    {2, ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
-ATTR_MAP(ReduceProdD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
-OUTPUT_MAP(ReduceProdD) = {{0, OUTPUT_DESC(y)}};
-REG_ADPT_DESC(ReduceProd, kNameReduceProd, ADPT_DESC(ReduceProdD))
-// ReduceAllD
-INPUT_MAP(ReduceAllD) = {{1, INPUT_DESC(x)}};
-INPUT_ATTR_MAP(ReduceAllD) = {
-    {2, ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
-ATTR_MAP(ReduceAllD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
-OUTPUT_MAP(ReduceAllD) = {{0, OUTPUT_DESC(y)}};
-REG_ADPT_DESC(ReduceAllD, prim::kPrimReduceAll->name(), ADPT_DESC(ReduceAllD))
 // ReduceAll
 INPUT_MAP(ReduceAll) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axes)}};
 ATTR_INPUT_MAP(ReduceAll) = {{"axis", "axes"}};
 ATTR_MAP(ReduceAll) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
 OUTPUT_MAP(ReduceAll) = {{0, OUTPUT_DESC(y)}};
-REG_ADPT_DESC(ReduceAll, prim::kPrimReduceAllD->name(), ADPT_DESC(ReduceAll))
+REG_ADPT_DESC(ReduceAll, prim::kPrimReduceAll->name(), ADPT_DESC(ReduceAll))
+REG_ADPT_DESC(ReduceAllD, prim::kPrimReduceAllD->name(), ADPT_DESC(ReduceAll))
 // ReduceMean
 INPUT_MAP(ReduceMean) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axes)}};
@@ -99,34 +84,20 @@ OUTPUT_MAP(ReduceMean) = {{0, OUTPUT_DESC(y)}};
 REG_ADPT_DESC(ReduceMean, prim::kPrimReduceMean->name(), ADPT_DESC(ReduceMean))
 REG_ADPT_DESC(ReduceMeanD, prim::kPrimReduceMeanD->name(), ADPT_DESC(ReduceMean))
-// ReduceMinD
-INPUT_MAP(ReduceMinD) = {{1, INPUT_DESC(x)}};
-INPUT_ATTR_MAP(ReduceMinD) = {
-    {2, ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
-ATTR_MAP(ReduceMinD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
-OUTPUT_MAP(ReduceMinD) = {{0, OUTPUT_DESC(y)}};
-REG_ADPT_DESC(ReduceMinD, prim::kPrimReduceMin->name(), ADPT_DESC(ReduceMinD))
 // ReduceMin
 INPUT_MAP(ReduceMin) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axes)}};
 ATTR_INPUT_MAP(ReduceMin) = {{"axis", "axes"}};
 ATTR_MAP(ReduceMin) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
 OUTPUT_MAP(ReduceMin) = {{0, OUTPUT_DESC(y)}};
-REG_ADPT_DESC(ReduceMin, prim::kPrimReduceMinD->name(), ADPT_DESC(ReduceMin))
-// ReduceMaxD
-INPUT_MAP(ReduceMaxD) = {{1, INPUT_DESC(x)}};
-INPUT_ATTR_MAP(ReduceMaxD) = {
-    {2, ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
-ATTR_MAP(ReduceMaxD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
-OUTPUT_MAP(ReduceMaxD) = {{0, OUTPUT_DESC(y)}};
-REG_ADPT_DESC(ReduceMax, prim::kPrimReduceMax->name(), ADPT_DESC(ReduceMaxD))
+REG_ADPT_DESC(ReduceMin, prim::kPrimReduceMin->name(), ADPT_DESC(ReduceMin))
+REG_ADPT_DESC(ReduceMinD, prim::kPrimReduceMinD->name(), ADPT_DESC(ReduceMin))
 // ReduceMax
 INPUT_MAP(ReduceMax) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axes)}};
 ATTR_INPUT_MAP(ReduceMax) = {{"axis", "axes"}};
 ATTR_MAP(ReduceMax) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
 OUTPUT_MAP(ReduceMax) = {{0, OUTPUT_DESC(y)}};
+REG_ADPT_DESC(ReduceMax, prim::kPrimReduceMax->name(), ADPT_DESC(ReduceMax))
+REG_ADPT_DESC(ReduceMaxD, prim::kPrimReduceMaxD->name(), ADPT_DESC(ReduceMax))
 // ReduceStd
@@ -142,6 +113,7 @@ INPUT_MAP(ReduceProd) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axes)}};
 ATTR_INPUT_MAP(ReduceProd) = {{"axis", "axes"}};
 ATTR_MAP(ReduceProd) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
 OUTPUT_MAP(ReduceProd) = {{0, OUTPUT_DESC(y)}};
+REG_ADPT_DESC(ReduceProd, prim::kPrimReduceProd->name(), ADPT_DESC(ReduceProd))
 REG_ADPT_DESC(DynamicReduceProd, kNameDynamicReduceProd, ADPT_DESC(ReduceProd))
 REG_ADPT_DESC(ReduceProdD, prim::kPrimReduceProdD->name(), ADPT_DESC(ReduceProd))
 } // namespace mindspore::transform
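
The reduce-op changes above also retire the legacy -D adapters entirely: those took the reduction axes as a compile-time attribute, while the surviving dynamic adapters take an axes input tensor, and the ATTR_INPUT_MAP entries such as {"axis", "axes"} bridge legacy call sites to the new form. The sketch below illustrates that attr-to-input normalization under assumed names; OpCall and NormalizeInputs are illustrative only, not MindSpore APIs.

// Illustrative sketch of attr-to-input normalization; OpCall and
// NormalizeInputs are assumed names, not MindSpore APIs.
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

struct OpCall {
  std::vector<std::vector<int64_t>> inputs;       // runtime input tensors
  std::optional<std::vector<int64_t>> axis_attr;  // legacy attribute form
};

// Mirrors ATTR_INPUT_MAP(ReduceMax) = {{"axis", "axes"}}: if a call carries
// the legacy attribute, materialize it as the trailing `axes` input so one
// dynamic adapter serves both calling conventions.
std::vector<std::vector<int64_t>> NormalizeInputs(const OpCall &call) {
  auto inputs = call.inputs;
  if (call.axis_attr && inputs.size() < 2) {
    inputs.push_back(*call.axis_attr);  // attribute becomes the axes input
  }
  return inputs;
}

int main() {
  OpCall legacy{{{4, 8, 16}}, std::vector<int64_t>{0, 2}};  // ReduceMaxD style
  OpCall dynamic{{{4, 8, 16}, {0, 2}}, std::nullopt};       // ReduceMax style
  std::cout << NormalizeInputs(legacy).size() << " "        // prints 2
            << NormalizeInputs(dynamic).size() << "\n";     // prints 2
  return 0;
}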

@@ -25,24 +25,12 @@ namespace mindspore::transform {
 DECLARE_OP_ADAPTER(ReduceMean)
 DECLARE_OP_USE_OUTPUT(ReduceMean)
-DECLARE_OP_ADAPTER(ReduceMinD)
-DECLARE_OP_USE_INPUT_ATTR(ReduceMinD)
-DECLARE_OP_USE_OUTPUT(ReduceMinD)
 DECLARE_OP_ADAPTER(ReduceMin)
 DECLARE_OP_USE_OUTPUT(ReduceMin)
-DECLARE_OP_ADAPTER(ReduceMaxD)
-DECLARE_OP_USE_INPUT_ATTR(ReduceMaxD)
-DECLARE_OP_USE_OUTPUT(ReduceMaxD)
 DECLARE_OP_ADAPTER(ReduceMax)
 DECLARE_OP_USE_OUTPUT(ReduceMax)
-DECLARE_OP_ADAPTER(ReduceAllD)
-DECLARE_OP_USE_INPUT_ATTR(ReduceAllD)
-DECLARE_OP_USE_OUTPUT(ReduceAllD)
 DECLARE_OP_ADAPTER(ReduceAll)
 DECLARE_OP_USE_OUTPUT(ReduceAll)
@@ -64,10 +52,6 @@ DECLARE_OP_USE_OUTPUT(ReduceSum)
 DECLARE_OP_ADAPTER(ReduceAny)
 DECLARE_OP_USE_OUTPUT(ReduceAny)
-DECLARE_OP_ADAPTER(ReduceProdD)
-DECLARE_OP_USE_INPUT_ATTR(ReduceProdD)
-DECLARE_OP_USE_OUTPUT(ReduceProdD)
 DECLARE_OP_ADAPTER(ReduceStd)
 DECLARE_OP_USE_OUTPUT(ReduceStd)

@@ -205,13 +205,6 @@ ATTR_MAP(UnsortedSegmentProd) = EMPTY_ATTR_MAP;
 OUTPUT_MAP(UnsortedSegmentProd) = {{0, OUTPUT_DESC(y)}};
 REG_ADPT_DESC(UnsortedSegmentProd, kNameUnsortedSegmentProd, ADPT_DESC(UnsortedSegmentProd))
-// UnsortedSegmentMaxD
-INPUT_MAP(UnsortedSegmentMaxD) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)}};
-INPUT_ATTR_MAP(UnsortedSegmentMaxD) = {{3, ATTR_DESC(num_segments, AnyTraits<int64_t>())}};
-ATTR_MAP(UnsortedSegmentMaxD) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(UnsortedSegmentMaxD) = {{0, OUTPUT_DESC(y)}};
-REG_ADPT_DESC(UnsortedSegmentMaxD, kNameUnsortedSegmentMax, ADPT_DESC(UnsortedSegmentMaxD))
 // UnsortedSegmentMin
 INPUT_MAP(UnsortedSegmentMin) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)}, {3, INPUT_DESC(num_segments)}};
 ATTR_MAP(UnsortedSegmentMin) = EMPTY_ATTR_MAP;
@@ -296,5 +289,6 @@ INPUT_MAP(UnsortedSegmentMax) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)
 ATTR_INPUT_MAP(UnsortedSegmentMax) = {{"num_segments", "num_segments"}};
 ATTR_MAP(UnsortedSegmentMax) = EMPTY_ATTR_MAP;
 OUTPUT_MAP(UnsortedSegmentMax) = {{0, OUTPUT_DESC(y)}};
-REG_ADPT_DESC(UnsortedSegmentMax, kUnsortedSegmentMaxDOpName, ADPT_DESC(UnsortedSegmentMax))
+REG_ADPT_DESC(UnsortedSegmentMax, kUnsortedSegmentMaxOpName, ADPT_DESC(UnsortedSegmentMax))
+REG_ADPT_DESC(UnsortedSegmentMaxD, kUnsortedSegmentMaxDOpName, ADPT_DESC(UnsortedSegmentMax))
 } // namespace mindspore::transform

@@ -64,10 +64,6 @@ DECLARE_OP_USE_OUTPUT(UnsortedSegmentSum)
 DECLARE_OP_ADAPTER(UnsortedSegmentProd)
 DECLARE_OP_USE_OUTPUT(UnsortedSegmentProd)
-DECLARE_OP_ADAPTER(UnsortedSegmentMaxD)
-DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentMaxD)
-DECLARE_OP_USE_OUTPUT(UnsortedSegmentMaxD)
 DECLARE_OP_ADAPTER(UnsortedSegmentMax)
 DECLARE_OP_USE_OUTPUT(UnsortedSegmentMax)

@@ -124,7 +124,7 @@ REG_ADPT_DESC(SpaceToBatch, kSpaceToBatchDOpName, ADPT_DESC(SpaceToBatch))
 // ExtractVolumePatches
 INPUT_MAP(ExtractVolumePatches) = {{1, INPUT_DESC(x)}};
 ATTR_MAP(ExtractVolumePatches) = {
-    {"ksizes", ATTR_DESC(ksizes, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
+    {"kernel_size", ATTR_DESC(ksizes, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
     {"strides", ATTR_DESC(strides, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
     {"padding", ATTR_DESC(padding, AnyTraits<std::string>())}};
 OUTPUT_MAP(ExtractVolumePatches) = {{0, OUTPUT_DESC(y)}};

@@ -54,21 +54,10 @@ void DynamicRNNShapeCheck(const PrimitivePtr &primitive, const std::vector<Abstr
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx0]->BuildShape())[kShape];
   auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx1]->BuildShape())[kShape];
   auto b_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx2]->BuildShape())[kShape];
-  auto seq_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx3]->BuildShape())[kShape];
-  auto h_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx4]->BuildShape())[kShape];
-  auto c_shape_ptr = input_args[kDynRnnIdx5]->BuildShape();
-  auto c_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx5]->BuildShape())[kShape];
   (void)CheckAndConvertUtils::CheckInteger("x_shape", SizeToLong(x_shape.size()), kEqual, kDynamicRnnShapeX, op_name);
   (void)CheckAndConvertUtils::CheckInteger("w_shape", SizeToLong(w_shape.size()), kEqual, kDynamicRnnShapeW, op_name);
   (void)CheckAndConvertUtils::CheckInteger("b_shape", SizeToLong(b_shape.size()), kEqual, kDynamicRnnShapeB, op_name);
-  (void)CheckAndConvertUtils::CheckInteger("h_shape", SizeToLong(h_shape.size()), kEqual, kDynamicRnnShapeH, op_name);
-  (void)CheckAndConvertUtils::CheckInteger("c_shape", SizeToLong(c_shape.size()), kEqual, kDynamicRnnShapeC, op_name);
-  int64_t batch_size = x_shape[kDynRnnIdx1];
   int64_t input_size = x_shape[kDynRnnIdx2];
-  if (seq_shape.size() != 0) {
-    MS_EXCEPTION(ValueError) << "For '" << op_name << "', input 'seq' shape must be 0, but got " << seq_shape.size()
-                             << ".";
-  }
   int64_t hidden_size = w_shape[w_shape.size() - 1] / kDynRnnNum4;
   if (w_shape[w_shape.size() - 1] % kDynRnnNum4 != 0) {
     MS_EXCEPTION(ValueError) << "For '" << op_name << "', w_shape[-1] should multiple of 4, now is "
@@ -83,11 +72,30 @@ void DynamicRNNShapeCheck(const PrimitivePtr &primitive, const std::vector<Abstr
     MS_EXCEPTION(ValueError) << "For '" << op_name << "', b_shape[0] should equal to w_shape[1], but gets "
                              << b_shape[0] << ".";
   }
-  (void)CheckAndConvertUtils::CheckInteger("h_shape[0]", h_shape[kDynRnnIdx0], kEqual, (int64_t)1, op_name);
-  (void)CheckAndConvertUtils::CheckInteger("h_shape[1]", h_shape[kDynRnnIdx1], kEqual, (int64_t)batch_size, op_name);
-  (void)CheckAndConvertUtils::CheckInteger("h_shape[2]", h_shape[kDynRnnIdx2], kEqual, (int64_t)hidden_size, op_name);
-  const std::map<std::string, BaseShapePtr> shapes = {{"c_shape", c_shape_ptr}};
-  (void)CheckAndConvertUtils::CheckTensorShapeSame(shapes, h_shape, op_name);
+  if (input_args.size() > kDynRnnIdx3) {
+    auto seq_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx3]->BuildShape())[kShape];
+    if (seq_shape.size() != 0) {
+      MS_EXCEPTION(ValueError) << "For '" << op_name << "', input 'seq' shape must be 0, but got " << seq_shape.size()
+                               << ".";
+    }
+  }
+  if (input_args.size() > kDynRnnIdx4) {
+    int64_t batch_size = x_shape[kDynRnnIdx1];
+    auto h_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx4]->BuildShape())[kShape];
+    (void)CheckAndConvertUtils::CheckInteger("h_shape", SizeToLong(h_shape.size()), kEqual, kDynamicRnnShapeH, op_name);
+    (void)CheckAndConvertUtils::CheckInteger("h_shape[0]", h_shape[kDynRnnIdx0], kEqual, (int64_t)1, op_name);
+    (void)CheckAndConvertUtils::CheckInteger("h_shape[1]", h_shape[kDynRnnIdx1], kEqual, (int64_t)batch_size, op_name);
+    (void)CheckAndConvertUtils::CheckInteger("h_shape[2]", h_shape[kDynRnnIdx2], kEqual, (int64_t)hidden_size, op_name);
+    if (input_args.size() > kDynRnnIdx5) {
+      auto c_shape_ptr = input_args[kDynRnnIdx5]->BuildShape();
+      auto c_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDynRnnIdx5]->BuildShape())[kShape];
+      (void)CheckAndConvertUtils::CheckInteger("c_shape", SizeToLong(c_shape.size()), kEqual, kDynamicRnnShapeC,
+                                               op_name);
+      const std::map<std::string, BaseShapePtr> shapes = {{"c_shape", c_shape_ptr}};
+      (void)CheckAndConvertUtils::CheckTensorShapeSame(shapes, h_shape, op_name);
+    }
+  }
 }
 abstract::TupleShapePtr DynamicRNNInferShape(const PrimitivePtr &primitive,
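
This last change is the "infer bug" half of the commit title: seq_length, init_h, and init_c are optional trailing inputs to DynamicRNN, and the old checker indexed input_args[kDynRnnIdx3] through input_args[kDynRnnIdx5] unconditionally, which presumably failed during shape inference whenever those inputs were omitted. A minimal sketch of the guard pattern used in the fix, with hypothetical helper names standing in for the CheckAndConvertUtils calls above:

// Minimal sketch of the guard pattern; CheckOptionalInputs and Shape are
// hypothetical, standing in for the CheckAndConvertUtils calls above.
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>

using Shape = std::vector<long>;

void CheckOptionalInputs(const std::vector<Shape> &input_args) {
  constexpr std::size_t kIdxH = 4;  // init_h slot (cf. kDynRnnIdx4)
  constexpr std::size_t kIdxC = 5;  // init_c slot (cf. kDynRnnIdx5)
  if (input_args.size() > kIdxH) {  // validate init_h only if it was passed
    if (input_args[kIdxH].size() != 3) {
      throw std::invalid_argument("init_h must be rank 3");
    }
    if (input_args.size() > kIdxC) {  // init_c is only checked alongside init_h
      if (input_args[kIdxC] != input_args[kIdxH]) {
        throw std::invalid_argument("init_c shape must match init_h");
      }
    }
  }
}

int main() {
  CheckOptionalInputs({{16, 32, 64}, {96, 256}, {256}});  // optionals omitted: OK
  CheckOptionalInputs({{16, 32, 64}, {96, 256}, {256}, {}, {1, 32, 64}, {1, 32, 64}});
  std::cout << "checks passed\n";
  return 0;
}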