!39861 r1.8 to master
Merge pull request !39861 from caifubi/master-pynative-r1.8-to-master
This commit is contained in:
commit 9844b2786e
@@ -283,6 +283,19 @@ tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
   return tensor;
 }
 
+AnfNodePtr CreateTensorMoveOp(const FuncGraphPtr &graph, const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto prim = std::make_shared<Primitive>(kTensorMoveOpName);
+  std::vector<AnfNodePtr> new_node_inputs = {NewValueNode(prim), node};
+  auto new_node = graph->NewCNode(new_node_inputs);
+  MS_EXCEPTION_IF_NULL(new_node);
+  new_node->set_abstract(node->abstract());
+  new_node->set_scope(node->scope());
+  common::AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue<std::vector<std::string>>({}), new_node);
+  return new_node;
+}
+
 bool IsAllNopNode(const session::KernelGraph *const graph) {
   MS_EXCEPTION_IF_NULL(graph);
   auto execution_order = graph->execution_order();

@@ -169,6 +169,8 @@ tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_pt
 
 BACKEND_EXPORT tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple);
 
+AnfNodePtr CreateTensorMoveOp(const FuncGraphPtr &graph, const AnfNodePtr &node);
+
 bool IsAllNopNode(const session::KernelGraph *const graph);
 
 void HideNopNode(session::KernelGraph *const graph);

@@ -0,0 +1,56 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/common/pass/insert_tensor_move_for_communication.h"
+#include "include/common/utils/anfalgo.h"
+
+namespace mindspore {
+namespace opt {
+bool InsertTensorMoveForCommunication::Run(const FuncGraphPtr &graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(graph);
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+
+  std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
+  for (auto &node : node_list) {
+    if (node == nullptr || !common::AnfAlgo::IsFusedCommunicationOp(node)) {
+      continue;
+    }
+    auto communication_op = node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(communication_op);
+    auto input_num = common::AnfAlgo::GetInputNum(communication_op);
+    for (size_t i = 0; i < input_num; ++i) {
+      auto input = common::AnfAlgo::GetInputNode(communication_op, i);
+      // Need to insert TensorMove in these cases:
+      // 1. (Parameter/ValueNode) -> CommunicationOp.
+      // 2. (Parameter/ValueNode) -> NopNode -> CommunicationOp.
+      // 3. (Parameter/ValueNode) -> RefNode -> CommunicationOp.
+      auto real_input_with_index = common::AnfAlgo::VisitKernelWithReturnType(input, 0, true);
+      MS_EXCEPTION_IF_NULL(real_input_with_index.first);
+      if (real_input_with_index.first->isa<Parameter>() || real_input_with_index.first->isa<ValueNode>() ||
+          kernel_graph->IsInRefOutputMap(real_input_with_index)) {
+        auto tensor_move = CreateTensorMoveOp(graph, input);
+        FuncGraphManagerPtr manager = graph->manager();
+        MS_EXCEPTION_IF_NULL(manager);
+        manager->SetEdge(communication_op, SizeToInt(i) + 1, tensor_move);
+        MS_LOG(DEBUG) << "Insert TensorMove for op " << communication_op->fullname_with_scope();
+      }
+    }
+  }
+  return true;
+}
+}  // namespace opt
+}  // namespace mindspore

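The pass above rewires every fused-communication input that comes straight from a Parameter, ValueNode, or ref output, so the collective reads from a freshly inserted TensorMove (copy) node instead of the original buffer. Below is a minimal standalone sketch of that edge-splicing idea; it does not use the MindSpore APIs, and the Node struct, the "Parameter"/"AllReduce"/"TensorMove" kind strings, and InsertTensorMove are names invented for this illustration only.

// Standalone sketch (not MindSpore code): splice a copy node on every edge that
// feeds a communication node directly from a parameter-like producer.
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Node {
  std::string kind;                           // e.g. "Parameter", "AllReduce", "TensorMove"
  std::vector<std::shared_ptr<Node>> inputs;  // edges to producer nodes
};

using NodePtr = std::shared_ptr<Node>;

void InsertTensorMove(const NodePtr &comm) {
  for (auto &input : comm->inputs) {
    if (input->kind == "Parameter") {
      auto move = std::make_shared<Node>(Node{"TensorMove", {input}});
      input = move;  // the collective now reads from the copy, not the parameter's own buffer
    }
  }
}

int main() {
  auto param = std::make_shared<Node>(Node{"Parameter", {}});
  auto all_reduce = std::make_shared<Node>(Node{"AllReduce", {param}});
  InsertTensorMove(all_reduce);
  std::cout << all_reduce->inputs[0]->kind << "\n";  // prints: TensorMove
  return 0;
}

In the real pass the rewiring goes through FuncGraphManager::SetEdge, as in the hunk above, so the new edge is visible to the whole graph rather than to a single node.
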
@@ -0,0 +1,38 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_INSERT_TENSOR_MOVE_FOR_COMMUNICATION_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_INSERT_TENSOR_MOVE_FOR_COMMUNICATION_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+#include "backend/common/optimizer/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+// If the input Tensor of the graph is connected to the AllReduce operator,
+// and the input Tensor of the graph already has a device address,
+// we need to copy the data in the device address to the contiguous memory of AllReduce.
+class InsertTensorMoveForCommunication : public Pass {
+ public:
+  InsertTensorMoveForCommunication() : Pass("insert_tensor_move_for_communication") {}
+  ~InsertTensorMoveForCommunication() override = default;
+  bool Run(const FuncGraphPtr &graph) override;
+};
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_INSERT_TENSOR_MOVE_FOR_COMMUNICATION_H_

@@ -454,6 +454,13 @@ std::vector<std::vector<tensor::TensorPtr>> GetRunGraphInputs(const GraphCompile
 
   return input_tensor_lists;
 }
+
+bool IsAutoParallel() {
+  auto parallel_context = parallel::ParallelContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(parallel_context);
+  auto parallel_mode = parallel_context->parallel_mode();
+  return parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel;
+}
 }  // namespace
 
 VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) {

@@ -1048,16 +1055,6 @@ void MindRTBackend::RunGraphByActors(const ActorInfo &actor_info, const GraphCom
 
   ConstructOutputs(actor_set, outputs, root_graph_);
 
-  MS_EXCEPTION_IF_NULL(root_graph_);
-  if (root_graph_->has_flag(kFlagIsPynativeBpropGraph)) {
-    if (graph_compiler_info.device_contexts_.empty()) {
-      MS_LOG(EXCEPTION) << "RunGraph failed, actor_info " << actor_info << " has no device_context";
-    }
-    MS_EXCEPTION_IF_NULL(graph_compiler_);
-    graph_compiler_->DoAllReduceOnGrads(actor_info, actor_set->output_actor_->outputs(),
-                                        graph_compiler_info.device_contexts_.front());
-  }
-
   runtime::GraphScheduler::GetInstance().ClearActorData(actor_set);
   // Close abstract_lock for dynamic_shape
   AnfUtils::CloseAbstractLock();

@@ -1124,7 +1121,7 @@ void MindRTBackend::RunGraphBySingleOp(const GraphCompilerInfo &graph_compiler_i
 
     // Save grad node to Bucket
     if (graph->has_flag(kFlagIsPynativeBpropGraph) && (!common::AnfAlgo::IsControlOpExecInBackend(kernel)) &&
-        !kernel->is_parallel()) {
+        !kernel->is_parallel() && IsAutoParallel()) {
       graph_compiler_->AddGradAddrToBucket(graph->graph_id(), graph_output_info.graph_output_tensors);
     }
   }

@@ -294,8 +294,18 @@ void GeGraphExecutor::AllocInputHostMemory(const KernelGraphPtr &kernel_graph) c
       continue;
     }
     TypeId output_type_id = common::AnfAlgo::GetOutputInferDataType(input_node, 0);
+
+    size_t tensor_size;
+    if (kernel_graph->is_dynamic_shape()) {
+      tensor_size = 0;
+    } else {
+      std::vector<size_t> shape = Convert2SizeT(common::AnfAlgo::GetOutputInferShape(input_node, 0));
+      size_t type_size = GetTypeByte(TypeIdToType(output_type_id));
+      tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
+    }
+
     auto device_address_ptr =
-      std::make_shared<cpu::CPUDeviceAddress>(nullptr, 0, kOpFormat_DEFAULT, output_type_id, kCPUDevice, 0);
+      std::make_shared<cpu::CPUDeviceAddress>(nullptr, tensor_size, kOpFormat_DEFAULT, output_type_id, kCPUDevice, 0);
     device_address_ptr->set_is_ptr_persisted(false);
     AnfAlgo::SetOutputAddr(device_address_ptr, 0, input_node.get());
   }

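The static-shape branch above derives the host buffer size by folding the inferred shape into the per-element byte size with std::accumulate. A small self-contained sketch of that computation, with an example shape and a float32 element size chosen here purely for illustration:

// Minimal sketch of the size computation used in the static-shape branch above.
#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<size_t> shape = {2, 3, 4};  // example inferred shape (assumption)
  size_t type_size = 4;                   // e.g. float32 -> 4 bytes per element
  size_t tensor_size =
      std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
  std::cout << tensor_size << "\n";  // 96 bytes = 2 * 3 * 4 * 4
  return 0;
}
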
@@ -24,14 +24,6 @@
 #include "utils/trace_base.h"
 
 namespace mindspore {
-namespace {
-bool IsPyNativeMode() {
-  auto ms_context = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(ms_context);
-  return ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
-}
-}  // namespace
-
 bool HcomUtil::GetKernelInputShape(const AnfNodePtr &anf_node, vector<ShapeVector> *hccl_kernel_intput_shape_list) {
   MS_EXCEPTION_IF_NULL(anf_node);
   MS_EXCEPTION_IF_NULL(hccl_kernel_intput_shape_list);

@@ -173,8 +165,7 @@ bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vector<HcclDataTyp
         block_size = input_size;
       }
     } else {
-      block_size =
-        IsPyNativeMode() ? input_size : (input_size + align_size - 1 + filled_size) / align_size * align_size;
+      block_size = (input_size + align_size - 1 + filled_size) / align_size * align_size;
     }
     total_size = total_size + block_size;
   }

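The retained expression rounds each input up to the fusion alignment before it is added to total_size. A worked example of that round-up, with an assumed align_size of 512 bytes and filled_size of 0 (values picked for illustration, not taken from the PR):

// Worked example of the round-up used for block_size above. align_size and the
// sample input sizes are assumptions for illustration only.
#include <cstddef>
#include <iostream>

size_t AlignUp(size_t input_size, size_t align_size, size_t filled_size) {
  return (input_size + align_size - 1 + filled_size) / align_size * align_size;
}

int main() {
  std::cout << AlignUp(1, 512, 0) << "\n";    // 512  (1 byte still occupies a full block)
  std::cout << AlignUp(512, 512, 0) << "\n";  // 512  (already aligned)
  std::cout << AlignUp(513, 512, 0) << "\n";  // 1024 (spills into a second block)
  return 0;
}
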
@@ -574,18 +574,5 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
   new_node->set_inputs(new_inputs);
   return new_node;
 }
-
-AnfNodePtr CreateTensorMoveOp(const FuncGraphPtr &graph, const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(node);
-  auto prim = std::make_shared<Primitive>(kTensorMoveOpName);
-  std::vector<AnfNodePtr> new_node_inputs = {NewValueNode(prim), node};
-  auto new_node = graph->NewCNode(new_node_inputs);
-  MS_EXCEPTION_IF_NULL(new_node);
-  new_node->set_abstract(node->abstract());
-  new_node->set_scope(node->scope());
-  common::AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue<std::vector<std::string>>({}), new_node);
-  return new_node;
-}
 }  // namespace opt
 }  // namespace mindspore

@@ -120,8 +120,6 @@ AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodeP
 
 CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
 
-AnfNodePtr CreateTensorMoveOp(const FuncGraphPtr &graph, const AnfNodePtr &node);
-
 AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                  const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input);
 

@@ -36,6 +36,7 @@
 #include "backend/common/pass/communication_op_fusion.h"
 #include "backend/common/pass/replace_node_by_proxy.h"
 #include "backend/common/pass/erase_visit_attr.h"
+#include "backend/common/pass/insert_tensor_move_for_communication.h"
 #include "common/graph_kernel/adapter/graph_kernel_optimization.h"
 #include "common/graph_kernel/adapter/expander.h"
 #include "common/graph_kernel/value_graph_binder.h"

@@ -167,6 +168,7 @@ void CPUKernelExecutor::OptimizeGraphImpl(const KernelGraphPtr &graph) const {
   pm->AddPass(std::make_shared<opt::AllReduceFusion>());
   pm->AddPass(std::make_shared<opt::InsertCastCPU>("insert_cast"));
   pm->AddPass(std::make_shared<opt::EraseVisitAttr>());
+  pm->AddPass(std::make_shared<opt::InsertTensorMoveForCommunication>());
   optimizer->AddPassManager(pm);
   (void)optimizer->Optimize(graph);
   graph->SetExecOrderByDefault();

@@ -30,6 +30,10 @@ namespace device {
 namespace gpu {
 void AssignGpuStream(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
+  if (kernel_graph->has_flag(kFlagPyNativeRunInGraph)) {
+    // All operators in pynative mode use default_stream.
+    return;
+  }
   std::vector<CNodePtr> allreduce_kernels;
   auto execution_kernels = kernel_graph->execution_order();
   for (auto kernel_node : execution_kernels) {

@@ -314,6 +314,7 @@ void GPUKernelExecutor::OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph)
   pm->AddPass(std::make_shared<opt::ConcatOutputsForAllGather>());
   pm->AddPass(std::make_shared<opt::GetitemTuple>());
   pm->AddPass(std::make_shared<opt::ReducePrecisionFusion>("reduce_precision"));
+  pm->AddPass(std::make_shared<opt::InsertTensorMoveForCommunication>());
   optimizer->AddPassManager(pm);
   (void)optimizer->Optimize(graph);
   graph->SetExecOrderByDefault();

@@ -22,6 +22,7 @@
 #include "backend/common/optimizer/pass_manager.h"
 #include "backend/common/optimizer/common_backend_optimization.h"
 #include "backend/common/pass/adjust_depend_for_parallel_optimizer_recompute_all_gather.h"
+#include "backend/common/pass/insert_tensor_move_for_communication.h"
 #include "plugin/device/gpu/optimizer/adam_weight_decay_fusion.h"
 #include "plugin/device/gpu/optimizer/adam_fusion.h"
 #include "plugin/device/gpu/optimizer/alltoall_fusion.h"

@@ -264,6 +264,24 @@ void SyncTensorTrunk(const std::vector<std::vector<TensorPtr>> &input_tensors) {
     }
   }
 }
+
+void UpdateDataNodeDeviceAddressSize(const AnfNodePtr &input_node, const TensorPtr &input_tensor,
+                                     const device::DeviceAddressPtr &device_address) {
+  MS_EXCEPTION_IF_NULL(input_node);
+  MS_EXCEPTION_IF_NULL(input_tensor);
+  MS_EXCEPTION_IF_NULL(device_address);
+  TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(input_node, 0);
+  if (output_type_id == kTypeUnknown) {
+    output_type_id = common::AnfAlgo::GetOutputInferDataType(input_node, 0);
+  }
+  auto device_shape =
+    trans::TransShapeToDevice(input_tensor->shape(), device_address->format(), input_node, 0, output_type_id);
+  size_t type_size = GetTypeByte(TypeIdToType(output_type_id));
+  auto device_address_size = type_size * SizeOf(device_shape);
+  MS_LOG(INFO) << "Size of device_address is updated from " << device_address->GetSize() << " to "
+               << device_address_size;
+  device_address->SetSize(device_address_size);
+}
 }  // namespace
 void DataPrepareActor::Init() {
   MS_EXCEPTION_IF_NULL(graph_compiler_info_);

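UpdateDataNodeDeviceAddressSize recomputes the device-side size from the device-format shape because padded device layouts can need more bytes than the host tensor. A rough standalone sketch of the effect, assuming an NC1HWC0-style layout that pads channels up to C0 = 16; the concrete shape and numbers are illustrative only, not taken from the PR:

// Sketch of why a padded device format can need more memory than the host tensor.
#include <cstddef>
#include <iostream>

int main() {
  size_t n = 1, c = 3, h = 224, w = 224;  // example NCHW host shape (assumption)
  size_t type_size = 2;                   // e.g. float16 -> 2 bytes per element
  size_t host_size = n * c * h * w * type_size;

  size_t c0 = 16;                         // channel padding unit (assumption)
  size_t c1 = (c + c0 - 1) / c0;          // channels rounded up to whole C0 blocks
  size_t device_size = n * c1 * h * w * c0 * type_size;

  std::cout << host_size << " vs " << device_size << "\n";  // 301056 vs 1605632
  return 0;
}
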
@@ -329,8 +347,18 @@ void DataPrepareActor::UpdateDeviceAddressForDataNode(const AnfNodePtr &input_no
   auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
   MS_EXCEPTION_IF_NULL(device_address);
   if (device_address->GetPtr() == nullptr) {
-    // Sync tensor data size to device address for allocating the appropriate size.
-    device_address->SetSize(tensor_data_size);
+    if (graph->is_dynamic_shape()) {
+      auto device_format = device_address->format();
+      static const std::set<std::string> kNormalFormat = {
+        kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN,
+      };
+      if (kNormalFormat.find(device_format) != kNormalFormat.end()) {
+        device_address->SetSize(tensor_data_size);
+      } else {
+        // Size of 5D format device_address is larger than tensor_data_size.
+        UpdateDataNodeDeviceAddressSize(input_node, input_tensor, device_address);
+      }
+    }
   }
   // If tensor address and device address are different (heterogeneous scenarios), or device address is persisted
   // Update device address data in data source actor process.

@@ -25,6 +25,8 @@ from mindspore.ops.operations.comm_ops import AllReduce, AllGather
 from mindspore.parallel._auto_parallel_context import auto_parallel_context
 import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
+from mindspore.common.api import ms_function
+from mindspore.common.api import is_pynative_parallel
 
 
 reduce_opt = C.MultitypeFuncGraph("reduce_opt")

@@ -411,7 +413,10 @@ class DistributedGradReducer(Cell):
         self.ps_parameters = tuple(ps_filter(x) for x in parameters)
         self.enable_parameter_server = any(self.ps_parameters)
         self.mode = context.get_context("mode")
+        self.is_pynative_parallel = is_pynative_parallel()
+        self.enable_tuple_broaden = True
 
+    @ms_function
     def construct(self, grads):
         """
         Under certain circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the

@@ -426,7 +431,7 @@ class DistributedGradReducer(Cell):
         """
         datatypes = self.map_(F.partial(_get_datatype), grads)
         grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads)
-        if self.mode == context.PYNATIVE_MODE:
+        if self.is_pynative_parallel:
             new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean), self.allreduce_filter, grads)
         elif self.split_fusion:
             if self.enable_parameter_server: