forked from mindspore-Ecosystem/mindspore
use DeviceMemory for memory control
parent 3c307cf486
commit f806b72447
@@ -29,52 +29,55 @@
 namespace mindspore {
 namespace parallel {
-#define DEVICE_MEMORY 1024.0 * 1024.0 * 1024.0 // 1GB

 // Get the target node's weight for sorting.
 double GetWeights(const Graph::NodeType &node) {
   const OperatorRec &op = node.apply;

-  if (op.op_type == 0) {
+  if (op.op_type == OperatorType::kRecMatMul) {
     // For MatMul
     auto cost_ptr = std::make_shared<CostMatMul>();

     return cost_ptr->GetMinCostIn(op);
-  } else if (op.op_type == 1) {
+  } else if (op.op_type == OperatorType::kRecConvolution) {
     // For Convolution
     auto cost_ptr = std::make_shared<CostConvolution>();

     return cost_ptr->GetMinCostIn(node);
-  } else if (op.op_type == 2) {
+  } else if (op.op_type == OperatorType::kRecPooling) {
     // For Pooling
     auto cost_ptr = std::make_shared<CostPooling>();

     return cost_ptr->GetMinCostIn();
-  } else if (op.op_type == 3) {
+  } else if (op.op_type == OperatorType::kRecAdd) {
     // For Add
     auto cost_ptr = std::make_shared<CostAdd>();

     return cost_ptr->GetMinCostIn();
-  } else if (op.op_type == 4 || op.op_type == 7 || op.op_type == 9) {
+  } else if (op.op_type == OperatorType::kRecSoftmax || op.op_type == OperatorType::kRecReLU ||
+             op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) {
     // For Softmax & Activation
     auto cost_ptr = std::make_shared<CostCommon>();

     return cost_ptr->GetMinCostIn();
-  } else if (op.op_type == 5) {
+  } else if (op.op_type == OperatorType::kRecReshape) {
     // For Reshape
     auto cost_ptr = std::make_shared<CostReshape>();

     return cost_ptr->GetMinCostIn();
-  } else if (op.op_type == 6) {
+  } else if (op.op_type == OperatorType::kRecBiasAdd) {
     // For BiasAdd
     auto cost_ptr = std::make_shared<CostBiasAdd>();

     return cost_ptr->GetMinCostIn();
-  } else if (op.op_type == 8) {
+  } else if (op.op_type == OperatorType::kRecBatchNorm) {
     // For BatchNorm
     auto cost_ptr = std::make_shared<CostBatchNorm>();

     return cost_ptr->GetMinCostIn();
+  } else if (op.op_type == OperatorType::kRecUnkownType) {
+    // For unknown type
+    return 0.0;
   } else {
     MS_LOG(EXCEPTION) << "Failure: GetOperatorWeight failed.";
   }
 }
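Aside (editor's sketch, not part of this commit): the hunk above swaps magic-number op types for named OperatorType enumerators. A minimal, self-contained illustration of that pattern is below; OpType, FakeNode, and the cost values are hypothetical stand-ins, not MindSpore APIs.

#include <iostream>

// Named enumerators replace opaque integers such as 0, 1, 2 ...
enum class OpType { kMatMul, kConvolution, kPooling, kUnknown };

struct FakeNode {
  OpType op_type;
};

// Readable dispatch: each case names the operator it prices.
double GetWeightSketch(const FakeNode &node) {
  switch (node.op_type) {
    case OpType::kMatMul:      return 2.0;  // placeholder cost
    case OpType::kConvolution: return 4.0;  // placeholder cost
    case OpType::kPooling:     return 1.0;  // placeholder cost
    case OpType::kUnknown:     return 0.0;  // unknown ops carry no weight
  }
  return 0.0;
}

int main() {
  FakeNode n{OpType::kConvolution};
  std::cout << GetWeightSketch(n) << std::endl;  // prints 4
  return 0;
}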
@@ -155,13 +158,17 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
     auto cost_ptr = std::make_shared<CostBatchNorm>();

     return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
+  } else if (node.apply.op_type == 10) {
+    // For unknown type
+    StrategyRec default_strategy;
+    return default_strategy;
   } else {
     MS_LOG(EXCEPTION) << "Failure: Partition Operator failed.";
   }
 }

 // Partition graph into all devices.
-Status PartitionForAllDevices(const size_t num_device, std::shared_ptr<Graph> graph) {
+Status PartitionForAllDevices(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph) {
   if (num_device < 1) {
     MS_LOG(EXCEPTION) << "ERROR: Number of devices can't be " << num_device << ".";
   }
@@ -207,7 +214,7 @@ Status PartitionForAllDevices(const size_t num_device, std::shared_ptr<Graph> gr
   }

   InferUndecideStrategy(graph);
-  if (DevicesMemoryControl(graph) != SUCCESS) {
+  if (DevicesMemoryControl(device_memory, graph) != SUCCESS) {
     return FAILED;
   } else {
     return SUCCESS;
@@ -306,15 +313,15 @@ void ApplyNextStrategy(const uint64_t node_index, std::shared_ptr<Graph> graph)
   }
 }

-Status DevicesMemoryControl(std::shared_ptr<Graph> graph) {
+Status DevicesMemoryControl(const double device_memory, std::shared_ptr<Graph> graph) {
   MS_EXCEPTION_IF_NULL(graph);

   uint64_t iter_nodes = graph->nodes.size();
+  double used_memory = 0.0;

   for (uint64_t i_node = 0; i_node < iter_nodes; i_node++) {
     if (graph->nodes[i_node].info == 0) {
       Graph::NodeType &Node = graph->nodes[i_node];
-      double used_memory = 0.0;

       for (int index = 0; index < 2; index++) {
         used_memory += Node.apply.arguments[index].tensor_str.str_n * Node.apply.arguments[index].tensor_shape.shape_n *
@@ -329,12 +336,12 @@ Status DevicesMemoryControl(std::shared_ptr<Graph> graph) {
                      Node.tensor_parm.tensor_str.str_h * Node.tensor_parm.tensor_shape.shape_h *
                      Node.tensor_parm.tensor_str.str_w * Node.tensor_parm.tensor_shape.shape_w *
                      GetDataTypeSize(Node.tensor_parm.tensor_type);
-      if (DEVICE_MEMORY < used_memory) {
-        MS_LOG(EXCEPTION) << "Failure: Out of memory!";
-        return FAILED;
-      }
     }
   }
+  if (device_memory < used_memory) {
+    MS_LOG(EXCEPTION) << "Failure: Out of memory!";
+    return FAILED;
+  }

   return SUCCESS;
 }
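Aside (editor's sketch, not part of this commit): the two hunks above change DevicesMemoryControl from checking a hard-coded 1 GB macro per node to accumulating one total across all nodes and comparing it once against a caller-supplied budget. The sketch below shows that idea in isolation; TensorSketch and NodeSketch are simplified stand-ins, not the rec_graph types.

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

struct TensorSketch {
  double str_n = 1.0, str_c = 1.0, str_h = 1.0, str_w = 1.0;   // fraction of each dim kept per device
  int64_t shape_n = 1, shape_c = 1, shape_h = 1, shape_w = 1;  // full tensor shape
  double elem_bytes = 4.0;                                     // e.g. float32
  double Bytes() const {
    return str_n * shape_n * str_c * shape_c * str_h * shape_h * str_w * shape_w * elem_bytes;
  }
};

struct NodeSketch {
  std::vector<TensorSketch> inputs;
  TensorSketch output;
};

// Accumulates the per-device footprint across all nodes and throws once the
// caller-supplied budget (in bytes) is exceeded, mirroring the reworked check.
double CheckDeviceMemory(const std::vector<NodeSketch> &nodes, double device_memory) {
  double used_memory = 0.0;
  for (const auto &node : nodes) {
    for (const auto &t : node.inputs) used_memory += t.Bytes();
    used_memory += node.output.Bytes();
  }
  if (device_memory < used_memory) throw std::runtime_error("Failure: Out of memory!");
  return used_memory;
}

int main() {
  TensorSketch t;
  t.str_n = 0.5;                      // half the batch dimension stays on this device
  t.shape_n = 1024;
  t.shape_c = 1024;
  NodeSketch matmul;
  matmul.inputs = {t, t};
  matmul.output = t;
  double device_memory = 1024.0 * 1024.0 * 1024.0 * 16.0;  // 16 GB, as in the updated tests
  std::cout << CheckDeviceMemory({matmul}, device_memory) << " bytes used\n";
  return 0;
}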
@@ -40,7 +40,7 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
                           const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
                           std::shared_ptr<Graph> graph);

-Status PartitionForAllDevices(const size_t num_device, std::shared_ptr<Graph> graph);
+Status PartitionForAllDevices(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph);

 Graph::NodeType ApplyStrToTensor(Graph::NodeType Node);
@@ -50,7 +50,7 @@ void ApplyLastStrategy(const uint64_t node_index, std::shared_ptr<Graph> graph);

 void ApplyNextStrategy(const uint64_t node_index, std::shared_ptr<Graph> graph);

-Status DevicesMemoryControl(std::shared_ptr<Graph> graph);
+Status DevicesMemoryControl(const double device_memory, std::shared_ptr<Graph> graph);

 size_t GetDataTypeSize(const TensorType &type);
 } // namespace parallel
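Aside (editor's sketch, not part of this commit): GetDataTypeSize, declared just above, is what the memory estimate multiplies by. A plausible shape for such a helper is a plain dtype-to-bytes mapping; the enum values below are assumptions for illustration, not the real TensorType.

#include <cstddef>
#include <iostream>
#include <stdexcept>

enum class FakeTensorType { kFloat32, kFloat16, kInt8 };

// Maps a (hypothetical) tensor element type to its size in bytes.
size_t DataTypeSizeSketch(FakeTensorType type) {
  switch (type) {
    case FakeTensorType::kFloat32: return 4;
    case FakeTensorType::kFloat16: return 2;
    case FakeTensorType::kInt8:    return 1;
  }
  throw std::invalid_argument("unknown tensor type");
}

int main() {
  std::cout << DataTypeSizeSketch(FakeTensorType::kFloat16) << " bytes per element\n";
  return 0;
}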
@@ -150,14 +150,11 @@ class OperatorInfo {
   // needed by rec_parser
   void set_type(const std::string &type) { type_ = type; }
   const std::string &type() const { return type_; }
-  void set_cnode_name(const std::string &cnode_name) { cnode_name_ = cnode_name; }
-  const std::string &cnode_name() const { return cnode_name_; }
   const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }

  protected:
   // needed by rec_parser
   std::string type_;
-  std::string cnode_name_;
   virtual Status CheckStrategy(const StrategyPtr &strategy) = 0;
   virtual Status InferTensorMap() = 0;
   virtual Status InferForwardCommunication() = 0;
@@ -935,7 +935,8 @@ Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const
   std::shared_ptr<Graph> graph = ParseGraph(ops, input_tensor_names);

   size_t num_device = g_device_manager->DeviceNum();
-  if (PartitionForAllDevices(num_device, graph) == SUCCESS) {
+  double device_memory = entire_costgraph->GetDeviceMemory();
+  if (PartitionForAllDevices(num_device, device_memory, graph) == SUCCESS) {
     MS_LOG(INFO) << "Partition Success With " << num_device << " devices.";
   } else {
     MS_LOG(ERROR) << "PartitionForAllDevices failed.";
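Aside (editor's sketch, not part of this commit): the hunk above is the caller-side wiring. The budget now comes from the cost graph (here assumed to be a byte count returned by GetDeviceMemory) instead of the removed DEVICE_MEMORY macro, and both the device count and the budget are handed to the partitioner. A toy version of that flow, with all names as stand-ins:

#include <cstddef>
#include <iostream>

// Mirrors the "can't be 0 devices" guard and the single budget comparison.
bool PartitionSketch(size_t num_device, double device_memory_bytes) {
  if (num_device < 1) return false;
  const double estimated_usage = 8.0e9;            // placeholder estimate from the graph
  return estimated_usage <= device_memory_bytes;   // fits the per-device budget?
}

int main() {
  size_t num_device = 8;
  double device_memory = 1024.0 * 1024.0 * 1024.0 * 16.0;  // 16 GB budget
  std::cout << (PartitionSketch(num_device, device_memory) ? "fits" : "out of memory") << "\n";
  return 0;
}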
@@ -2263,13 +2263,10 @@ std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node) {
   std::vector<AnfNodePtr> all_inputs = node->inputs();
   std::vector<AnfNodePtr> node_inputs{all_inputs.begin() + 1, all_inputs.end()};

+  std::string node_id = node->UniqueId();
+  name_inputs.push_back(node_id);
   for (auto &input : node_inputs) {
-    std::string name;
-    if (IsValueNode<Tensor>(input) || input->isa<CNode>() || input->isa<Parameter>()) {
-      name = input->ToString();
-    } else {
-      continue;
-    }
+    std::string name = input->UniqueId();
     name_inputs.push_back(name);
   }
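Aside (editor's sketch, not part of this commit): the parser hunk above switches from printable names (ToString) to unique ids (UniqueId), so two nodes that print identically no longer collide in the name-to-strategy bookkeeping. A toy illustration with hypothetical types:

#include <iostream>
#include <string>
#include <vector>

struct FakeAnfNode {
  std::string printable;  // what a ToString-style call would return
  int id;                 // what a UniqueId-style call would disambiguate by
  std::string UniqueId() const { return printable + "_" + std::to_string(id); }
};

// The node's own id comes first, then one id per input, as in the updated parser.
std::vector<std::string> ExtractInputNamesSketch(const FakeAnfNode &node,
                                                 const std::vector<FakeAnfNode> &inputs) {
  std::vector<std::string> names;
  names.push_back(node.UniqueId());
  for (const auto &in : inputs) {
    names.push_back(in.UniqueId());
  }
  return names;
}

int main() {
  FakeAnfNode matmul{"MatMul", 7};
  std::vector<FakeAnfNode> ins{{"Parameter", 1}, {"Parameter", 2}};  // same printable name
  for (const auto &n : ExtractInputNamesSketch(matmul, ins)) std::cout << n << "\n";
  return 0;
}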
@@ -227,19 +227,22 @@ TEST_F(TestPartition, test_PartitionNode) {

 TEST_F(TestPartition, test_PartitionForAllDevices) {
   std::shared_ptr<Graph> graph = MakeMatMulData(9);
-  ASSERT_EQ(PartitionForAllDevices(1024, graph), SUCCESS);
+  double device_memory = 1024.0 * 1024.0 * 1024.0 * 16.0;
+  ASSERT_EQ(PartitionForAllDevices(1024, device_memory, graph), SUCCESS);
 }

 TEST_F(TestPartition, test_PartitionForAllDevices2) {
   std::shared_ptr<Graph> graph = MakeMatMulData(9);
-  ASSERT_EQ(PartitionForAllDevices(2, graph), SUCCESS);
+  double device_memory = 1024.0 * 1024.0 * 1024.0 * 16.0;
+  ASSERT_EQ(PartitionForAllDevices(2, device_memory, graph), SUCCESS);
 }

 // Negative case: partition on 0 devices
 TEST_F(TestPartition, test_PartitionForAllDevices0) {
   std::shared_ptr<Graph> graph = MakeMatMulData(9);
+  double device_memory = 1024.0 * 1024.0 * 1024.0 * 16.0;
   // Throw Exception "Number of devices can't be 0"
-  EXPECT_ANY_THROW(PartitionForAllDevices(0, graph));
+  EXPECT_ANY_THROW(PartitionForAllDevices(0, device_memory, graph));
 }

 TEST_F(TestPartition, test_ApplyStrToTensor) {
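Aside (editor's sketch, not part of this commit): with the budget now a parameter, the out-of-memory branch could also be exercised directly by passing a tiny budget. The hypothetical test below assumes the same TestPartition fixture and MakeMatMulData helper used above; it is an illustration, not a change proposed by this diff.

// Hypothetical follow-up test: a 1-byte budget should trip the
// "Failure: Out of memory!" exception in DevicesMemoryControl.
TEST_F(TestPartition, test_PartitionForAllDevicesTinyMemory) {
  std::shared_ptr<Graph> graph = MakeMatMulData(9);
  double device_memory = 1.0;  // far below any real graph's footprint
  EXPECT_ANY_THROW(PartitionForAllDevices(2, device_memory, graph));
}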