forked from mindspore-Ecosystem/mindspore
!265 Synchronize Ascend software suite 07 Jul 2020
Merge pull request !265 from yanghaoran/incubator
This commit is contained in:
commit
a661545d49
|
@ -15,6 +15,7 @@ include(${GE_SOURCE_DIR}/cmake/external_libs/securec.cmake)
|
|||
if (NOT ENABLE_D)
|
||||
set(GE_PREBUILD_PATH ${GE_SOURCE_DIR}/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR})
|
||||
find_library(slog libslog.so ${GE_PREBUILD_PATH})
|
||||
find_library(error_manager liberror_manager.so ${GE_PREBUILD_PATH})
|
||||
elseif (DEFINED ENV{D_LINK_PATH})
|
||||
set(GE_LIB_PATH $ENV{D_LINK_PATH})
|
||||
set(GE_SYS_ARCH "")
|
||||
|
|
|
@ -156,6 +156,7 @@ if (NOT ENABLE_GE)
|
|||
set(ASCEND_PATH /usr/local/Ascend)
|
||||
endif ()
|
||||
set(ASCEND_DRIVER_PATH ${ASCEND_PATH}/driver/lib64/common)
|
||||
set(ASCEND_FWK_PATH ${ASCEND_PATH}/fwkacllib/lib64)
|
||||
|
||||
install(
|
||||
FILES
|
||||
|
@ -164,6 +165,7 @@ if (NOT ENABLE_GE)
|
|||
${CMAKE_BINARY_DIR}/graphengine/src/ge/ge_runtime/libge_runtime.so
|
||||
${ASCEND_DRIVER_PATH}/libslog.so
|
||||
${ASCEND_DRIVER_PATH}/libc_sec.so
|
||||
${ASCEND_FWK_PATH}/liberror_manager.so
|
||||
DESTINATION ${INSTALL_LIB_DIR}
|
||||
COMPONENT mindspore
|
||||
)
|
||||
|
@ -172,6 +174,7 @@ if (NOT ENABLE_GE)
|
|||
FILES
|
||||
${CMAKE_BINARY_DIR}/graphengine/src/common/graph/libgraph.so
|
||||
${CMAKE_SOURCE_DIR}/graphengine/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}/libslog.so
|
||||
${CMAKE_SOURCE_DIR}/graphengine/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}/liberror_manager.so
|
||||
${CMAKE_SOURCE_DIR}/build/graphengine/libc_sec.so
|
||||
DESTINATION ${INSTALL_LIB_DIR}
|
||||
COMPONENT mindspore
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit 8891f0546c4a250095ff68e1262f58772b938fd9
|
||||
Subproject commit 4084909d62c159da6ba316f61ad3d02a4857b34b
|
|
@ -40,7 +40,7 @@ def get_ddk_version():
|
|||
with open(backup_ddk_info_file, "r") as fp:
|
||||
ddk_version = json.load(fp)["VERSION"]
|
||||
else:
|
||||
ddk_version = "1.60.T17.B830"
|
||||
ddk_version = "Ascend910"
|
||||
return ddk_version
|
||||
|
||||
|
||||
|
|
|
@ -192,8 +192,9 @@ if (ENABLE_D)
|
|||
find_library(CCE_LIB cce ${ASCEND_RUNTIME_PATH})
|
||||
find_library(RUNTIME_LIB runtime ${ASCEND_RUNTIME_PATH})
|
||||
find_library(TSDCLIENT tsdclient HINTS ${ASCEND_RUNTIME_PATH} ${ASCEND_DRIVER_BACK_PATH})
|
||||
find_library(DATATRANSFER datatransfer HINTS ${ASCEND_RUNTIME_PATH} ${ASCEND_DRIVER_BACK_PATH})
|
||||
find_library(PROFILING msprof ${ASCEND_DRIVER_PATH})
|
||||
target_link_libraries(mindspore ge_runtime ${CCE_LIB} ${RUNTIME_LIB} ${TSDCLIENT} ${PROFILING} ${HCCL} ${TSDCLIENT})
|
||||
target_link_libraries(mindspore ge_runtime ${CCE_LIB} ${RUNTIME_LIB} ${TSDCLIENT} ${PROFILING} ${HCCL} ${DATATRANSFER})
|
||||
endif()
|
||||
|
||||
# link protobuf
|
||||
|
|
|
@ -292,7 +292,6 @@ bool TbeKernelSelect::TbeCheckSupported(
|
|||
parallel::TOPK,
|
||||
parallel::IN_TOPK,
|
||||
parallel::PACK,
|
||||
parallel::GATHER_ND,
|
||||
parallel::UNSORTEF_SEGMENT_MIND,
|
||||
parallel::UNSORTEF_SEGMENT_PRODD,
|
||||
parallel::CAST};
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "pre_activate/ascend/ir_fission/batch_norm_grad_split.h"
|
||||
#include "pre_activate/ascend/ir_fission/batch_norm_bert_fission.h"
|
||||
#include "pre_activate/ascend/ir_fission/single_batch_norm_fission.h"
|
||||
#include "pre_activate/ascend/ir_fission/tensor_scatter_update_fission.h"
|
||||
#include "pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h"
|
||||
#include "pre_activate/ascend/ir_fission/layer_norm_grad_split.h"
|
||||
#include "pre_activate/pass/communication_op_fusion.h"
|
||||
|
@ -149,6 +150,7 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
|
|||
ir_fusion_pm->AddPass(std::make_shared<BatchNormGrad2BNInferGrad>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<BatchNormGradInferFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<SplitFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<TensorScatterUpdateFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
|
||||
}
|
||||
} // namespace
|
||||
|
@ -290,6 +292,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
|
|||
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<TensorScatterUpdateFission>());
|
||||
|
||||
optimizer->AddPassManager(ir_fusion_pm);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
|
|
|
@ -94,7 +94,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP
|
|||
origin_pair = FindRefOriginNode(input_node);
|
||||
MS_EXCEPTION_IF_NULL(origin_pair.first);
|
||||
if (!origin_pair.first->isa<Parameter>()) {
|
||||
MS_LOG(EXCEPTION) << "ref op origin node is not parameter";
|
||||
MS_LOG(WARNING) << "ref op origin node is not parameter";
|
||||
}
|
||||
MS_LOG(DEBUG) << "DealRefTransAndCast the node input index " << input_index << ", find origin op is "
|
||||
<< origin_pair.first->DebugString() << ", index is " << origin_pair.second;
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "pre_activate/ascend/ir_fission/tensor_scatter_update_fission.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "pre_activate/common/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
// Creates a TensorMove node fed by input(1) of `tensor_scatter_update`.
// The new node takes over the scope and abstract of `tensor_scatter_update`
// and has its use_locking attribute set to false.
CNodePtr CreateTensorMove(const FuncGraphPtr &graph, const CNodePtr &tensor_scatter_update) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(tensor_scatter_update);

  // Input 0 of a CNode is its primitive; the moved tensor follows it.
  std::vector<AnfNodePtr> move_inputs;
  move_inputs.push_back(NewValueNode(std::make_shared<Primitive>(kTensorMoveOpName)));
  move_inputs.push_back(tensor_scatter_update->input(1));

  auto move_node = graph->NewCNode(move_inputs);
  MS_EXCEPTION_IF_NULL(move_node);

  move_node->set_scope(tensor_scatter_update->scope());
  move_node->set_abstract(tensor_scatter_update->abstract());
  AnfAlgo::SetNodeAttr(kAttrUseLocking, MakeValue(false), move_node);
  return move_node;
}
|
||||
|
||||
// Creates a ScatterNdUpdate node whose inputs are `tensor_move` followed by
// input(2) and input(3) of `tensor_scatter_update`. The new node takes over
// the scope and abstract of `tensor_scatter_update`.
CNodePtr CreateScatterNdUpdate(const FuncGraphPtr &graph, const CNodePtr &tensor_scatter_update,
                               const CNodePtr &tensor_move) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(tensor_scatter_update);
  MS_EXCEPTION_IF_NULL(tensor_move);

  // Input 0 of a CNode is its primitive; the data inputs follow in order.
  std::vector<AnfNodePtr> scatter_inputs;
  scatter_inputs.push_back(NewValueNode(std::make_shared<Primitive>(kScatterNdUpdateOpName)));
  scatter_inputs.push_back(tensor_move);
  scatter_inputs.push_back(tensor_scatter_update->input(2));
  scatter_inputs.push_back(tensor_scatter_update->input(3));

  auto scatter_node = graph->NewCNode(scatter_inputs);
  MS_EXCEPTION_IF_NULL(scatter_node);

  scatter_node->set_scope(tensor_scatter_update->scope());
  scatter_node->set_abstract(tensor_scatter_update->abstract());
  return scatter_node;
}
|
||||
} // namespace
|
||||
|
||||
// Pattern: a TensorScatterUpdate primitive followed by a sequence variable
// standing for its inputs.
const BaseRef TensorScatterUpdateFission::DefinePattern() const {
  auto scatter_update_prim = std::make_shared<Primitive>(kTensorScatterUpdateOpName);
  VarPtr input_seq = std::make_shared<SeqVar>();
  return VectorRef({scatter_update_prim, input_seq});
}
|
||||
|
||||
// Replaces a matched TensorScatterUpdate node with TensorMove -> ScatterNdUpdate.
// Returns nullptr (no replacement) when `node` is not a CNode or does not have
// exactly 4 inputs; otherwise returns the new ScatterNdUpdate node.
const AnfNodePtr TensorScatterUpdateFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                     const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);

  auto origin_cnode = node->cast<CNodePtr>();
  if (origin_cnode == nullptr) {
    return nullptr;
  }
  // The helpers below read input(1)..input(3), so require exactly 4 inputs.
  if (origin_cnode->size() != 4) {
    return nullptr;
  }

  auto moved_input = CreateTensorMove(func_graph, origin_cnode);
  return CreateScatterNdUpdate(func_graph, origin_cnode, moved_input);
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TENSOR_SCATTER_UPDATE_FISSION_H_
|
||||
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TENSOR_SCATTER_UPDATE_FISSION_H_
|
||||
|
||||
#include "pre_activate/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class TensorScatterUpdateFission : public PatternProcessPass {
|
||||
public:
|
||||
explicit TensorScatterUpdateFission(bool multigraph = true)
|
||||
: PatternProcessPass("tensor_scatter_update_fission", multigraph) {}
|
||||
~TensorScatterUpdateFission() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TENSOR_SCATTER_UPDATE_FISSION_H_
|
|
@ -186,28 +186,18 @@ bool MsContext::OpenTsd() {
|
|||
}
|
||||
|
||||
MS_LOG(INFO) << "Device id = " << device_id << ", rank size = " << rank_size << ".";
|
||||
#if (defined(ENABLE_TDTQUE) && defined(ENABLE_GE))
|
||||
int32_t initStatus = tdt::TdtHostInit(device_id);
|
||||
if (initStatus != TDT_OK_CODE) {
|
||||
MS_LOG(EXCEPTION) << "Init tsd failed, status = " << initStatus << ".";
|
||||
return false;
|
||||
}
|
||||
tdt_print_ = std::thread(TensorPrint());
|
||||
#endif
|
||||
TDT_StatusT status = tdt::TsdClient::GetInstance()->Open(device_id, rank_size);
|
||||
if (status != TDT_OK) {
|
||||
MS_LOG(EXCEPTION) << "Device " << device_id << " is occupied, open tsd failed, status = " << status << ".";
|
||||
return false;
|
||||
}
|
||||
tsd_ref_++;
|
||||
#if (defined(ENABLE_TDTQUE) && !defined(ENABLE_GE))
|
||||
int32_t initStatus = tdt::TdtHostInit(device_id);
|
||||
if (initStatus != TDT_OK_CODE) {
|
||||
MS_LOG(EXCEPTION) << "Init tsd failed, status = " << initStatus << ".";
|
||||
return false;
|
||||
}
|
||||
tdt_print_ = std::thread(TensorPrint());
|
||||
#endif
|
||||
MS_LOG(INFO) << "Open and init tsd successful, tsd reference = " << tsd_ref_ << ".";
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -164,6 +164,18 @@ constexpr auto kStridedReadOpName = "StridedRead";
|
|||
constexpr auto kStridedWriteOpName = "StridedWrite";
|
||||
constexpr auto kFusedAdamWeightDecayName = "FusedAdamWeightDecay";
|
||||
constexpr auto kFusedAdamName = "FusedAdam";
|
||||
constexpr auto kApplyAdagradV2OpName = "ApplyAdagradV2";
|
||||
constexpr auto kSparseApplyAdagradV2OpName = "SparseApplyAdagradV2";
|
||||
constexpr auto kSparseApplyFtrlOpName = "SparseApplyFtrl";
|
||||
constexpr auto kSparseApplyFtrlV2OpName = "SparseApplyFtrlV2";
|
||||
constexpr auto kApplyKerasMomentumOpName = "ApplyKerasMomentum";
|
||||
constexpr auto kSparseApplyProximalAdagradOpName = "SparseApplyProximalAdagrad";
|
||||
constexpr auto kSparseApplyRMSPropOpName = "SparseApplyRMSProp";
|
||||
constexpr auto kSparseApplyAdadeltaOpName = "SparseApplyAdadelta";
|
||||
constexpr auto kApplyAdamWithAmsgradOpName = "ApplyAdamWithAmsgrad";
|
||||
constexpr auto kTensorMoveOpName = "TensorMove";
|
||||
constexpr auto kTensorScatterUpdateOpName = "TensorScatterUpdate";
|
||||
constexpr auto kScatterNdUpdateOpName = "ScatterNdUpdate";
|
||||
|
||||
// attr key name
|
||||
constexpr auto kAttrInputNames = "input_names";
|
||||
|
@ -224,6 +236,9 @@ constexpr auto kAttrOutputNum = "output_num";
|
|||
constexpr auto kAttrSizeSplits = "size_splits";
|
||||
constexpr auto kAttrOutputDefault = "output_default";
|
||||
constexpr auto kAttrPrimitiveTarget = "primitive_target";
|
||||
constexpr auto kAttrReduceScatterFlag = "reduce_scatter_flag";
|
||||
constexpr auto kAttrOffset = "offset";
|
||||
constexpr auto kAttrUseLocking = "use_locking";
|
||||
|
||||
// attr value
|
||||
constexpr auto kValueTargetSwitch = "target_switch";
|
||||
|
|
|
@ -28,8 +28,9 @@ avg_pool_op_info = TBERegOp("AvgPool") \
|
|||
.attr("padding", "required", "str", "all") \
|
||||
.attr("data_format", "optional", "str", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "filter", False, "optional", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_FracZ, DataType.F16_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
|
|
|
@ -126,10 +126,6 @@ class ModelCallback(Callback):
|
|||
print("epoch: {}, outputs are: {}".format(cb_params.cur_epoch_num, str(cb_params.net_outputs)))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_bert_tdt():
|
||||
"""test bert tdt"""
|
||||
np.random.seed(0)
|
||||
|
|
|
@ -154,10 +154,6 @@ class TimeMonitor(Callback):
|
|||
self.epoch_mseconds_list.append(epoch_mseconds)
|
||||
self.per_step_mseconds_list.append(epoch_mseconds / self.data_size)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_bert_percision():
|
||||
"""test bert percision"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "common/backend_common_test.h"
|
||||
#include "common/py_func_graph_fetcher.h"
|
||||
#include "pre_activate/ascend/ir_fission/tensor_scatter_update_fission.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class TestHWOptTensorScatterUpdateFission : public BackendCommon {
|
||||
public:
|
||||
TestHWOptTensorScatterUpdateFission()
|
||||
: get_py_fun_("gtest_input.pre_activate.tensor_scatter_update_fission_test", true) {}
|
||||
~TestHWOptTensorScatterUpdateFission() override = default;
|
||||
|
||||
UT::PyFuncGraphFetcher get_py_fun_;
|
||||
};
|
||||
|
||||
// Runs TensorScatterUpdateFission on the "before" graph and checks that the
// optimized graph equals the "after" graph from the Python test module.
TEST_F(TestHWOptTensorScatterUpdateFission, test_fission) {
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tensor_scatter_update_fission", "before");
  // Fatal check: a null graph must not be handed on to GetKernelGraph.
  ASSERT_NE(g, nullptr);

  // Abstract inputs for (x, indices, updates).
  std::vector<int> shp1{2, 3};
  std::vector<int> shp2{2, 2};
  std::vector<int> shp3{2};
  auto inputx = std::make_shared<abstract::AbstractTensor>(kFloat32, shp1);
  auto indices = std::make_shared<abstract::AbstractTensor>(kInt32, shp2);
  auto update = std::make_shared<abstract::AbstractTensor>(kFloat32, shp3);
  AbstractBasePtrList args_spec_list{inputx, indices, update};
  auto fg = GetKernelGraph(g, args_spec_list);

  // Run only the pass under test.
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::TensorScatterUpdateFission>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(fg);

  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_tensor_scatter_update_fission", "after");
  // Fatal check: the graph comparison below needs a valid expected graph.
  ASSERT_NE(g_after, nullptr);
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,50 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from mindspore.ops import Primitive
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
tensor_scatter_update = P.TensorScatterUpdate()
|
||||
tensor_move = Primitive('TensorMove')
|
||||
scatter_nd_update = Primitive('ScatterNdUpdate')
|
||||
make_tuple = Primitive('make_tuple')
|
||||
tuple_getitem = Primitive('tuple_getitem')
|
||||
|
||||
|
||||
class FnDict:
    """Registry mapping function names to functions; an instance is used as a decorator."""

    def __init__(self):
        # Maps fn.__name__ -> fn for every function registered through __call__.
        self.fnDict = {}

    def __call__(self, fn):
        """Register `fn` under its `__name__` and return it unchanged."""
        self.fnDict[fn.__name__] = fn
        # Return fn so that a name decorated with this instance stays bound to
        # the function instead of being rebound to None.
        return fn

    def __getitem__(self, name):
        """Return the function registered under `name` (KeyError if there is none)."""
        return self.fnDict[name]
|
||||
|
||||
|
||||
def test_tensor_scatter_update_fission(tag):
    """Return the graph-building function registered under `tag`.

    "before" applies tensor_scatter_update to (x, indices, updates).
    "after" applies tensor_move to x, then scatter_nd_update to the result
    together with indices and updates, and wraps the output in make_tuple.
    """
    registry = FnDict()

    @registry
    def before(x, indices, updates):
        return tensor_scatter_update(x, indices, updates)

    @registry
    def after(x, indices, updates):
        moved = tensor_move(x)
        scattered = scatter_nd_update(moved, indices, updates)
        return make_tuple(scattered)

    return registry[tag]
|
|
@ -103,7 +103,8 @@ hcclResult_t hcom_receive(const char *tag, void *outputPtr, u64 count, hcclDataT
|
|||
|
||||
/* 获取梯度参数切分方案 */
|
||||
hcclResult_t hcom_get_split_strategy(const char *group, const struct model_feature *feature, u32 maxSegmentNum,
|
||||
u32 *segmentNum, u32 *segmentIdx, GradSplitForceMode force) {
|
||||
u32 *segmentNum, u32 *segmentIdx, GradSplitForceMode force,
|
||||
OriginalGraphShapeType shapeType) {
|
||||
return HCCL_SUCCESS;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue