diff --git a/cmake/dependency_graphengine.cmake b/cmake/dependency_graphengine.cmake index 91a471d1f26..8e1faa92c6b 100644 --- a/cmake/dependency_graphengine.cmake +++ b/cmake/dependency_graphengine.cmake @@ -15,6 +15,7 @@ include(${GE_SOURCE_DIR}/cmake/external_libs/securec.cmake) if (NOT ENABLE_D) set(GE_PREBUILD_PATH ${GE_SOURCE_DIR}/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}) find_library(slog libslog.so ${GE_PREBUILD_PATH}) + find_library(error_manager liberror_manager.so ${GE_PREBUILD_PATH}) elseif (DEFINED ENV{D_LINK_PATH}) set(GE_LIB_PATH $ENV{D_LINK_PATH}) set(GE_SYS_ARCH "") diff --git a/cmake/package.cmake b/cmake/package.cmake index 42821cf41dd..2fde01af4f2 100644 --- a/cmake/package.cmake +++ b/cmake/package.cmake @@ -156,6 +156,7 @@ if (NOT ENABLE_GE) set(ASCEND_PATH /usr/local/Ascend) endif () set(ASCEND_DRIVER_PATH ${ASCEND_PATH}/driver/lib64/common) + set(ASCEND_FWK_PATH ${ASCEND_PATH}/fwkacllib/lib64) install( FILES @@ -164,6 +165,7 @@ if (NOT ENABLE_GE) ${CMAKE_BINARY_DIR}/graphengine/src/ge/ge_runtime/libge_runtime.so ${ASCEND_DRIVER_PATH}/libslog.so ${ASCEND_DRIVER_PATH}/libc_sec.so + ${ASCEND_FWK_PATH}/liberror_manager.so DESTINATION ${INSTALL_LIB_DIR} COMPONENT mindspore ) @@ -172,6 +174,7 @@ if (NOT ENABLE_GE) FILES ${CMAKE_BINARY_DIR}/graphengine/src/common/graph/libgraph.so ${CMAKE_SOURCE_DIR}/graphengine/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}/libslog.so + ${CMAKE_SOURCE_DIR}/graphengine/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}/liberror_manager.so ${CMAKE_SOURCE_DIR}/build/graphengine/libc_sec.so DESTINATION ${INSTALL_LIB_DIR} COMPONENT mindspore diff --git a/graphengine b/graphengine index 8891f0546c4..4084909d62c 160000 --- a/graphengine +++ b/graphengine @@ -1 +1 @@ -Subproject commit 8891f0546c4a250095ff68e1262f58772b938fd9 +Subproject commit 4084909d62c159da6ba316f61ad3d02a4857b34b diff --git a/mindspore/_extends/parallel_compile/tbe_compiler/common.py b/mindspore/_extends/parallel_compile/tbe_compiler/common.py index 3d55cf60a2c..7287bace950 100644 --- a/mindspore/_extends/parallel_compile/tbe_compiler/common.py +++ b/mindspore/_extends/parallel_compile/tbe_compiler/common.py @@ -40,7 +40,7 @@ def get_ddk_version(): with open(backup_ddk_info_file, "r") as fp: ddk_version = json.load(fp)["VERSION"] else: - ddk_version = "1.60.T17.B830" + ddk_version = "Ascend910" return ddk_version diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 3f9965c0429..9dc1502aa5c 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -192,8 +192,9 @@ if (ENABLE_D) find_library(CCE_LIB cce ${ASCEND_RUNTIME_PATH}) find_library(RUNTIME_LIB runtime ${ASCEND_RUNTIME_PATH}) find_library(TSDCLIENT tsdclient HINTS ${ASCEND_RUNTIME_PATH} ${ASCEND_DRIVER_BACK_PATH}) + find_library(DATATRANSFER datatransfer HINTS ${ASCEND_RUNTIME_PATH} ${ASCEND_DRIVER_BACK_PATH}) find_library(PROFILING msprof ${ASCEND_DRIVER_PATH}) - target_link_libraries(mindspore ge_runtime ${CCE_LIB} ${RUNTIME_LIB} ${TSDCLIENT} ${PROFILING} ${HCCL} ${TSDCLIENT}) + target_link_libraries(mindspore ge_runtime ${CCE_LIB} ${RUNTIME_LIB} ${TSDCLIENT} ${PROFILING} ${HCCL} ${DATATRANSFER}) endif() # link protobuf diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.cc index 5ef5d50e9c7..9951321f5ec 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.cc @@ -292,7 +292,6 @@ bool TbeKernelSelect::TbeCheckSupported( parallel::TOPK, parallel::IN_TOPK, parallel::PACK, - parallel::GATHER_ND, parallel::UNSORTEF_SEGMENT_MIND, parallel::UNSORTEF_SEGMENT_PRODD, parallel::CAST}; diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc index ff864401b13..48ce87629ca 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc @@ -23,6 +23,7 @@ #include "pre_activate/ascend/ir_fission/batch_norm_grad_split.h" #include "pre_activate/ascend/ir_fission/batch_norm_bert_fission.h" #include "pre_activate/ascend/ir_fission/single_batch_norm_fission.h" +#include "pre_activate/ascend/ir_fission/tensor_scatter_update_fission.h" #include "pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h" #include "pre_activate/ascend/ir_fission/layer_norm_grad_split.h" #include "pre_activate/pass/communication_op_fusion.h" @@ -149,6 +150,7 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); } } // namespace @@ -290,6 +292,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptrAddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); optimizer->AddPassManager(ir_fusion_pm); (void)optimizer->Optimize(kernel_graph); diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc index f909dae9e41..3241684c621 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc @@ -94,7 +94,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP origin_pair = FindRefOriginNode(input_node); MS_EXCEPTION_IF_NULL(origin_pair.first); if (!origin_pair.first->isa()) { - MS_LOG(EXCEPTION) << "ref op origin node is not parameter"; + MS_LOG(WARNING) << "ref op origin node is not parameter"; } MS_LOG(DEBUG) << "DealRefTransAndCast the node input index " << input_index << ", find origin op is " << origin_pair.first->DebugString() << ", index is " << origin_pair.second; diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/tensor_scatter_update_fission.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/tensor_scatter_update_fission.cc new file mode 100644 index 00000000000..6e6cea5ae55 --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/tensor_scatter_update_fission.cc @@ -0,0 +1,71 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "pre_activate/ascend/ir_fission/tensor_scatter_update_fission.h" +#include +#include +#include "session/anf_runtime_algorithm.h" +#include "pre_activate/common/helper.h" + +namespace mindspore { +namespace opt { +namespace { +CNodePtr CreateTensorMove(const FuncGraphPtr &graph, const CNodePtr &tensor_scatter_update) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(tensor_scatter_update); + std::vector inputs = {NewValueNode(std::make_shared(kTensorMoveOpName)), + tensor_scatter_update->input(1)}; + auto tensor_move = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(tensor_move); + tensor_move->set_scope(tensor_scatter_update->scope()); + tensor_move->set_abstract(tensor_scatter_update->abstract()); + AnfAlgo::SetNodeAttr(kAttrUseLocking, MakeValue(false), tensor_move); + return tensor_move; +} + +CNodePtr CreateScatterNdUpdate(const FuncGraphPtr &graph, const CNodePtr &tensor_scatter_update, + const CNodePtr &tensor_move) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(tensor_scatter_update); + MS_EXCEPTION_IF_NULL(tensor_move); + std::vector inputs = {NewValueNode(std::make_shared(kScatterNdUpdateOpName)), tensor_move, + tensor_scatter_update->input(2), tensor_scatter_update->input(3)}; + auto scatter_nd_update = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(scatter_nd_update); + scatter_nd_update->set_scope(tensor_scatter_update->scope()); + scatter_nd_update->set_abstract(tensor_scatter_update->abstract()); + return scatter_nd_update; +} +} // namespace + +const BaseRef TensorScatterUpdateFission::DefinePattern() const { + VarPtr Xs = std::make_shared(); + auto prim = std::make_shared(kTensorScatterUpdateOpName); + return VectorRef({prim, Xs}); +} + +const AnfNodePtr TensorScatterUpdateFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + auto tensor_scatter_update = node->cast(); + if (tensor_scatter_update == nullptr || tensor_scatter_update->size() != 4) { + return nullptr; + } + auto tensor_move = CreateTensorMove(func_graph, tensor_scatter_update); + return CreateScatterNdUpdate(func_graph, tensor_scatter_update, tensor_move); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/tensor_scatter_update_fission.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/tensor_scatter_update_fission.h new file mode 100644 index 00000000000..0ada93ac708 --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/tensor_scatter_update_fission.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TENSOR_SCATTER_UPDATE_FISSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TENSOR_SCATTER_UPDATE_FISSION_H_ + +#include "pre_activate/common/optimizer.h" + +namespace mindspore { +namespace opt { +class TensorScatterUpdateFission : public PatternProcessPass { + public: + explicit TensorScatterUpdateFission(bool multigraph = true) + : PatternProcessPass("tensor_scatter_update_fission", multigraph) {} + ~TensorScatterUpdateFission() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TENSOR_SCATTER_UPDATE_FISSION_H_ diff --git a/mindspore/ccsrc/utils/context/ms_context.cc b/mindspore/ccsrc/utils/context/ms_context.cc index 0e3542d1cbc..2f2471f4600 100644 --- a/mindspore/ccsrc/utils/context/ms_context.cc +++ b/mindspore/ccsrc/utils/context/ms_context.cc @@ -186,28 +186,18 @@ bool MsContext::OpenTsd() { } MS_LOG(INFO) << "Device id = " << device_id << ", rank size = " << rank_size << "."; -#if (defined(ENABLE_TDTQUE) && defined(ENABLE_GE)) int32_t initStatus = tdt::TdtHostInit(device_id); if (initStatus != TDT_OK_CODE) { MS_LOG(EXCEPTION) << "Init tsd failed, status = " << initStatus << "."; return false; } tdt_print_ = std::thread(TensorPrint()); -#endif TDT_StatusT status = tdt::TsdClient::GetInstance()->Open(device_id, rank_size); if (status != TDT_OK) { MS_LOG(EXCEPTION) << "Device " << device_id << " is occupied, open tsd failed, status = " << status << "."; return false; } tsd_ref_++; -#if (defined(ENABLE_TDTQUE) && !defined(ENABLE_GE)) - int32_t initStatus = tdt::TdtHostInit(device_id); - if (initStatus != TDT_OK_CODE) { - MS_LOG(EXCEPTION) << "Init tsd failed, status = " << initStatus << "."; - return false; - } - tdt_print_ = std::thread(TensorPrint()); -#endif MS_LOG(INFO) << "Open and init tsd successful, tsd reference = " << tsd_ref_ << "."; return true; } diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 477ac350a86..e28adb6e216 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -164,6 +164,18 @@ constexpr auto kStridedReadOpName = "StridedRead"; constexpr auto kStridedWriteOpName = "StridedWrite"; constexpr auto kFusedAdamWeightDecayName = "FusedAdamWeightDecay"; constexpr auto kFusedAdamName = "FusedAdam"; +constexpr auto kApplyAdagradV2OpName = "ApplyAdagradV2"; +constexpr auto kSparseApplyAdagradV2OpName = "SparseApplyAdagradV2"; +constexpr auto kSparseApplyFtrlOpName = "SparseApplyFtrl"; +constexpr auto kSparseApplyFtrlV2OpName = "SparseApplyFtrlV2"; +constexpr auto kApplyKerasMomentumOpName = "ApplyKerasMomentum"; +constexpr auto kSparseApplyProximalAdagradOpName = "SparseApplyProximalAdagrad"; +constexpr auto kSparseApplyRMSPropOpName = "SparseApplyRMSProp"; +constexpr auto kSparseApplyAdadeltaOpName = "SparseApplyAdadelta"; +constexpr auto kApplyAdamWithAmsgradOpName = "ApplyAdamWithAmsgrad"; +constexpr auto kTensorMoveOpName = "TensorMove"; +constexpr auto kTensorScatterUpdateOpName = "TensorScatterUpdate"; +constexpr auto kScatterNdUpdateOpName = "ScatterNdUpdate"; // attr key name constexpr auto kAttrInputNames = "input_names"; @@ -224,6 +236,9 @@ constexpr auto kAttrOutputNum = "output_num"; constexpr auto kAttrSizeSplits = "size_splits"; constexpr auto kAttrOutputDefault = "output_default"; constexpr auto kAttrPrimitiveTarget = "primitive_target"; +constexpr auto kAttrReduceScatterFlag = "reduce_scatter_flag"; +constexpr auto kAttrOffset = "offset"; +constexpr auto kAttrUseLocking = "use_locking"; // attr value constexpr auto kValueTargetSwitch = "target_switch"; diff --git a/mindspore/ops/_op_impl/tbe/avg_pool.py b/mindspore/ops/_op_impl/tbe/avg_pool.py index 5db5947b01d..90d174474b9 100644 --- a/mindspore/ops/_op_impl/tbe/avg_pool.py +++ b/mindspore/ops/_op_impl/tbe/avg_pool.py @@ -28,8 +28,9 @@ avg_pool_op_info = TBERegOp("AvgPool") \ .attr("padding", "required", "str", "all") \ .attr("data_format", "optional", "str", "all") \ .input(0, "x", False, "required", "all") \ + .input(1, "filter", False, "optional", "all") \ .output(0, "y", False, "required", "all") \ - .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F16_5HD, DataType.F16_FracZ, DataType.F16_5HD) \ .get_op_info() diff --git a/tests/st/networks/models/bert/test_bert_graph_kernel.py b/tests/st/networks/models/bert/test_bert_graph_kernel.py index ec71cbaa4f3..4c9673e0767 100644 --- a/tests/st/networks/models/bert/test_bert_graph_kernel.py +++ b/tests/st/networks/models/bert/test_bert_graph_kernel.py @@ -126,10 +126,6 @@ class ModelCallback(Callback): print("epoch: {}, outputs are: {}".format(cb_params.cur_epoch_num, str(cb_params.net_outputs))) -@pytest.mark.level0 -@pytest.mark.platform_arm_ascend_training -@pytest.mark.platform_x86_ascend_training -@pytest.mark.env_onecard def test_bert_tdt(): """test bert tdt""" np.random.seed(0) diff --git a/tests/st/networks/models/bert/test_bert_tdt_lossscale.py b/tests/st/networks/models/bert/test_bert_tdt_lossscale.py index 29b4e7a5427..d4c56edbc12 100644 --- a/tests/st/networks/models/bert/test_bert_tdt_lossscale.py +++ b/tests/st/networks/models/bert/test_bert_tdt_lossscale.py @@ -154,10 +154,6 @@ class TimeMonitor(Callback): self.epoch_mseconds_list.append(epoch_mseconds) self.per_step_mseconds_list.append(epoch_mseconds / self.data_size) -@pytest.mark.level0 -@pytest.mark.platform_arm_ascend_training -@pytest.mark.platform_x86_ascend_training -@pytest.mark.env_onecard def test_bert_percision(): """test bert percision""" context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False) diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/tensor_scatter_update_fission_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/tensor_scatter_update_fission_test.cc new file mode 100644 index 00000000000..faebe0e4a01 --- /dev/null +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/tensor_scatter_update_fission_test.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/backend_common_test.h" +#include "common/py_func_graph_fetcher.h" +#include "pre_activate/ascend/ir_fission/tensor_scatter_update_fission.h" +#include "debug/anf_ir_dump.h" + +namespace mindspore { +namespace opt { +class TestHWOptTensorScatterUpdateFission : public BackendCommon { + public: + TestHWOptTensorScatterUpdateFission() + : get_py_fun_("gtest_input.pre_activate.tensor_scatter_update_fission_test", true) {} + ~TestHWOptTensorScatterUpdateFission() override = default; + + UT::PyFuncGraphFetcher get_py_fun_; +}; + +TEST_F(TestHWOptTensorScatterUpdateFission, test_fission) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tensor_scatter_update_fission", "before"); + EXPECT_NE(g, nullptr); + std::vector shp1{2, 3}; + std::vector shp2{2, 2}; + std::vector shp3{2}; + auto inputx = std::make_shared(kFloat32, shp1); + auto indices = std::make_shared(kInt32, shp2); + auto update = std::make_shared(kFloat32, shp3); + AbstractBasePtrList args_spec_list{inputx, indices, update}; + auto fg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(fg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_tensor_scatter_update_fission", "after"); + EXPECT_NE(g_after, nullptr); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} +} // namespace opt +} // namespace mindspore diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/tensor_scatter_update_fission_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/tensor_scatter_update_fission_test.py new file mode 100644 index 00000000000..4a84f346072 --- /dev/null +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/tensor_scatter_update_fission_test.py @@ -0,0 +1,50 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from mindspore.ops import Primitive +from mindspore.ops import operations as P + +tensor_scatter_update = P.TensorScatterUpdate() +tensor_move = Primitive('TensorMove') +scatter_nd_update = Primitive('ScatterNdUpdate') +make_tuple = Primitive('make_tuple') +tuple_getitem = Primitive('tuple_getitem') + + +class FnDict: + def __init__(self): + self.fnDict = {} + + def __call__(self, fn): + self.fnDict[fn.__name__] = fn + + def __getitem__(self, name): + return self.fnDict[name] + + +def test_tensor_scatter_update_fission(tag): + fns = FnDict() + + @fns + def before(x, indices, updates): + res = tensor_scatter_update(x, indices, updates) + return res + + @fns + def after(x, indices, updates): + res = tensor_move(x) + res = scatter_nd_update(res, indices, updates) + return make_tuple(res) + + return fns[tag] diff --git a/tests/ut/cpp/stub/hccl/hccl_stub.cc b/tests/ut/cpp/stub/hccl/hccl_stub.cc index e25ccc36c62..56f62910f21 100644 --- a/tests/ut/cpp/stub/hccl/hccl_stub.cc +++ b/tests/ut/cpp/stub/hccl/hccl_stub.cc @@ -103,7 +103,8 @@ hcclResult_t hcom_receive(const char *tag, void *outputPtr, u64 count, hcclDataT /* 获取梯度参数切分方案 */ hcclResult_t hcom_get_split_strategy(const char *group, const struct model_feature *feature, u32 maxSegmentNum, - u32 *segmentNum, u32 *segmentIdx, GradSplitForceMode force) { + u32 *segmentNum, u32 *segmentIdx, GradSplitForceMode force, + OriginalGraphShapeType shapeType) { return HCCL_SUCCESS; }