forked from mindspore-Ecosystem/mindspore
!4356 Add validation for field split
Merge pull request !4356 from yangzhenzhang/update-field-split
This commit is contained in:
commit
2db0290c49
|
@ -44,14 +44,15 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
|
||||||
auto device_arrangement = tensor_layout->device_arrangement().array();
|
auto device_arrangement = tensor_layout->device_arrangement().array();
|
||||||
auto tensor_map = tensor_layout->tensor_map().array();
|
auto tensor_map = tensor_layout->tensor_map().array();
|
||||||
auto slice_shape = tensor_layout->slice_shape().array();
|
auto slice_shape = tensor_layout->slice_shape().array();
|
||||||
int32_t _field_size = tensor_layout->get_field_size();
|
Shape field_size = {tensor_layout->get_field_size()};
|
||||||
Shape field_size;
|
Shape uniform_split;
|
||||||
if (_field_size != 0) {
|
if (tensor_layout->uniform_split()) {
|
||||||
field_size.push_back(_field_size);
|
uniform_split.push_back(1);
|
||||||
} else {
|
} else {
|
||||||
field_size = {0};
|
uniform_split.push_back(0);
|
||||||
}
|
}
|
||||||
std::vector<Shape> layout = {device_arrangement, tensor_map, slice_shape, field_size};
|
|
||||||
|
std::vector<Shape> layout = {device_arrangement, tensor_map, slice_shape, field_size, uniform_split};
|
||||||
dict[py::str(name)] = layout;
|
dict[py::str(name)] = layout;
|
||||||
MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString();
|
MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString();
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,6 +27,92 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace parallel {
|
namespace parallel {
|
||||||
|
Status GatherV2PInfo::GetManualSplitWithoutOffsetAttr() {
|
||||||
|
auto manual_split_without_offset_iter = attrs_.find("manual_split");
|
||||||
|
if (manual_split_without_offset_iter != attrs_.end()) {
|
||||||
|
manual_split_ = true;
|
||||||
|
MS_EXCEPTION_IF_NULL(manual_split_without_offset_iter->second);
|
||||||
|
if (manual_split_without_offset_iter->second->cast<ValueTuplePtr>() == nullptr) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Manual split without offset strategy's format is wrong! Need ValueSequeue";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
std::vector<ValuePtr> value_vector = manual_split_without_offset_iter->second->cast<ValueTuplePtr>()->value();
|
||||||
|
MS_LOG(INFO) << name_ << ": manual split with offset is " << manual_split_without_offset_iter->second->ToString();
|
||||||
|
|
||||||
|
int64_t offset = 0;
|
||||||
|
for (auto &ele : value_vector) {
|
||||||
|
index_offsets_.push_back(offset);
|
||||||
|
if (!ele->isa<Int32Imm>()) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": The element of manual split must be int";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
int64_t param_split_shape = static_cast<int64_t>(GetValue<int>(ele));
|
||||||
|
if (param_split_shape <= 0) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": The value of manual split must be positive, but got " << param_split_shape;
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
param_split_shapes_.push_back(param_split_shape);
|
||||||
|
offset += param_split_shape;
|
||||||
|
}
|
||||||
|
if (param_split_shapes_.empty()) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Failed to extract param split's split info";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GatherV2PInfo::GetManualSplitAttr() {
|
||||||
|
auto manual_split_with_offset_iter = attrs_.find("manual_split_with_offset");
|
||||||
|
if (manual_split_with_offset_iter != attrs_.end()) {
|
||||||
|
manual_split_ = true;
|
||||||
|
auto var = manual_split_with_offset_iter->second->cast<ValueTuplePtr>();
|
||||||
|
if (var == nullptr) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Manual split with offset strategy's format is wrong! Need ValueSequeue";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_LOG(INFO) << name_ << ": manual split with offset strategy " << var->ToString();
|
||||||
|
for (auto &ele : var->value()) {
|
||||||
|
if (!ele->isa<ValueSequeue>()) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Manual split with offset strategy's format is wrong! Need ValueSequeue";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
std::vector<ValuePtr> value_vector = ele->cast<ValueTuplePtr>()->value();
|
||||||
|
if (value_vector.size() != 2) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Size of manual split with offset's element must be 2";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
int64_t param_split_row = static_cast<int64_t>(GetValue<int>(value_vector[0]));
|
||||||
|
int64_t offset = static_cast<int64_t>(GetValue<int>(value_vector[1]));
|
||||||
|
if ((param_split_row <= 0) || (offset < 0)) {
|
||||||
|
MS_LOG(ERROR) << name_
|
||||||
|
<< ": The value of param split shape must be positive, and the offset must larger or equal to 0";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
param_split_shapes_.push_back(param_split_row);
|
||||||
|
index_offsets_.push_back(offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (param_split_shapes_.empty()) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Failed to extract param split with offset's split info";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int64_t &offset) { return offset < 0; })) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Index offset must not less than 0";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (GetManualSplitWithoutOffsetAttr() != SUCCESS) {
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
Status GatherV2PInfo::GetAttrs() {
|
Status GatherV2PInfo::GetAttrs() {
|
||||||
// get axis, the third input is the axis, is a ValueNode, embeddinglookup doesn't have axis.
|
// get axis, the third input is the axis, is a ValueNode, embeddinglookup doesn't have axis.
|
||||||
if (target_ != CPU) {
|
if (target_ != CPU) {
|
||||||
|
@ -53,58 +139,76 @@ Status GatherV2PInfo::GetAttrs() {
|
||||||
if (target_iter->second->isa<StringImm>()) {
|
if (target_iter->second->isa<StringImm>()) {
|
||||||
target_ = target_iter->second->cast<StringImmPtr>()->value();
|
target_ = target_iter->second->cast<StringImmPtr>()->value();
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(ERROR) << name_ << " : The value of target is not a string.";
|
MS_LOG(ERROR) << name_ << ": The value of target is not a string.";
|
||||||
}
|
|
||||||
}
|
|
||||||
auto manual_split_iter = attrs_.find("manual_split");
|
|
||||||
if (manual_split_iter != attrs_.end()) {
|
|
||||||
param_split_shapes_.clear();
|
|
||||||
manual_split_ = true;
|
|
||||||
auto var = manual_split_iter->second->cast<ValueTuplePtr>();
|
|
||||||
MS_LOG(DEBUG) << "Extract manual split strategy " << manual_split_iter->second->ToString();
|
|
||||||
|
|
||||||
if (var->size() > 0) {
|
|
||||||
std::vector<ValuePtr> elements = var->value();
|
|
||||||
for (auto &ele : elements) {
|
|
||||||
if (ele->isa<ValueSequeue>()) {
|
|
||||||
auto value_tuple = ele->cast<ValueTuplePtr>();
|
|
||||||
std::vector<ValuePtr> value_vector = value_tuple->value();
|
|
||||||
if (value_vector.size() != 2) {
|
|
||||||
MS_LOG(ERROR) << "Failure: Size of manual_split element must be 2.";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
param_split_shapes_.push_back(static_cast<int64_t>(GetValue<int>(value_vector[0])));
|
|
||||||
index_offsets_.push_back(static_cast<int64_t>(GetValue<int>(value_vector[1])));
|
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "Failure: Manual split strategy's format is wrong! Need ValueSequeue";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (param_split_shapes_.empty()) {
|
|
||||||
MS_LOG(ERROR) << "Failed to extract param split strategy.";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (GetManualSplitAttr() != SUCCESS) {
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (manual_split_ && (axis_ != 0)) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": The axis or offset must be 0 if manual split, bug got " << axis_;
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GatherV2PInfo::CheckManualSplit() {
|
Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) {
|
||||||
auto param_shape = inputs_shape_.at(0);
|
if (strategy.size() != 2) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << strategy.size();
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
Dimensions param_strategy = strategy[0];
|
||||||
|
Dimensions indices_strategy = strategy[1];
|
||||||
|
if (param_strategy.size() != 2 || indices_strategy.size() != 2) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": The size of param strategy or indices strategy must be 2";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (indices_strategy[0] != 1) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": The indices_strategy[0] must be 1, bug got " << indices_strategy[0];
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (param_strategy[0] != indices_strategy[1]) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": The param_strategy[0] must be equal to indices_strategy[1]";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (indices_strategy[1] != SizeToInt(param_split_shapes_.size())) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": The indices_strategy[1] must be equal to manual split size";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t min_param_slice_row = inputs_shape_[1][1] / indices_strategy[1];
|
||||||
|
bool invalid = std::any_of(param_split_shapes_.begin(), param_split_shapes_.end(),
|
||||||
|
[&min_param_slice_row](int64_t v) { return v < min_param_slice_row; });
|
||||||
|
if (invalid) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": The split value must be larger than or equal to indices slice's column num";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (inputs_shape_[0][0] < inputs_shape_[1][1]) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": The param's row smaller than indices' column";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Don't support repeated calc
|
||||||
|
CheckGlobalDeviceManager();
|
||||||
|
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
|
||||||
|
auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>());
|
||||||
|
if (IntToSize(product_p) < dev_num) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Manual split doesn't support repeated calc";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
int64_t split_shape_sum = std::accumulate(param_split_shapes_.begin(), param_split_shapes_.end(), 0,
|
int64_t split_shape_sum = std::accumulate(param_split_shapes_.begin(), param_split_shapes_.end(), 0,
|
||||||
[](int64_t s, int64_t shape) { return s + shape; });
|
[](int64_t s, int64_t shape) { return s + shape; });
|
||||||
if (split_shape_sum < param_shape.at(0)) {
|
if (split_shape_sum != inputs_shape_[0][0]) {
|
||||||
MS_LOG(ERROR) << "Failure: Sum of splited shapes should not be smaller than param_shape.";
|
MS_LOG(ERROR) << name_ << ": Sum of splited shapes must be equal to param_shape[0]";
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int64_t &offset) { return offset < 0; })) {
|
|
||||||
MS_LOG(ERROR) << "Failure: Index offset must not less than 0.";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
|
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -147,7 +251,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (manual_split_) {
|
if (manual_split_) {
|
||||||
if (CheckManualSplit() != SUCCESS) {
|
if (CheckManualSplit(strategy->GetInputDim()) != SUCCESS) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
// when using manual_split, no need to check belowings.
|
// when using manual_split, no need to check belowings.
|
||||||
|
@ -343,14 +447,15 @@ Status GatherV2PInfo::InferTensorInfo() {
|
||||||
SUCCESS)) {
|
SUCCESS)) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (manual_split_) {
|
||||||
|
input_tensor_layout.set_uniform_split(false);
|
||||||
|
}
|
||||||
// infer tensor info
|
// infer tensor info
|
||||||
TensorInfo input_tensor_info(input_tensor_layout);
|
TensorInfo input_tensor_info(input_tensor_layout);
|
||||||
TensorInfo input_index_info(input_index_layout);
|
TensorInfo input_index_info(input_index_layout);
|
||||||
TensorInfo output_tensor_info(output_tensor_layout);
|
TensorInfo output_tensor_info(output_tensor_layout);
|
||||||
|
|
||||||
Shape slice_shape = input_tensor_info.slice_shape();
|
|
||||||
MS_LOG(DEBUG) << "The fake slice shape is: " << ShapeToString(slice_shape);
|
|
||||||
|
|
||||||
inputs_tensor_info_.push_back(input_tensor_info);
|
inputs_tensor_info_.push_back(input_tensor_info);
|
||||||
inputs_tensor_info_.push_back(input_index_info);
|
inputs_tensor_info_.push_back(input_index_info);
|
||||||
outputs_tensor_info_.push_back(output_tensor_info);
|
outputs_tensor_info_.push_back(output_tensor_info);
|
||||||
|
@ -392,9 +497,17 @@ Status GatherV2PInfo::InferBias() {
|
||||||
Status GatherV2PInfo::InferOffset() {
|
Status GatherV2PInfo::InferOffset() {
|
||||||
CheckGlobalDeviceManager();
|
CheckGlobalDeviceManager();
|
||||||
size_t rank = g_device_manager->global_rank();
|
size_t rank = g_device_manager->global_rank();
|
||||||
if (rank < index_offsets_.size()) {
|
|
||||||
index_offset_ = index_offsets_.at(rank);
|
MS_EXCEPTION_IF_NULL(strategy_);
|
||||||
MS_LOG(DEBUG) << name_ << ": Device rank " << rank << ", Index Offset: " << index_offset_;
|
auto param_strategy = strategy_->GetInputDim()[0];
|
||||||
|
if (param_strategy.size() != 2) {
|
||||||
|
MS_LOG(ERROR) << "The size of param strategy must be 2";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
size_t index = rank / param_strategy[1];
|
||||||
|
if (index < index_offsets_.size()) {
|
||||||
|
index_offset_ = index_offsets_[index];
|
||||||
|
MS_LOG(INFO) << name_ << ": Device rank " << rank << ", Index Offset: " << index_offset_;
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -524,8 +637,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
||||||
ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
|
ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
|
||||||
if (manual_split_ && target_ != CPU) {
|
if (manual_split_ && target_ != CPU) {
|
||||||
if (ComputeReplaceGraph(cnode) != SUCCESS) {
|
if (ComputeReplaceGraph(cnode) != SUCCESS) {
|
||||||
MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
|
MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
return replace_graph_;
|
return replace_graph_;
|
||||||
}
|
}
|
||||||
|
@ -536,8 +648,7 @@ ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
if (param_strategy.at(IntToSize(axis_)) != 1 && ComputeReplaceGraph(cnode) != SUCCESS) {
|
if (param_strategy.at(IntToSize(axis_)) != 1 && ComputeReplaceGraph(cnode) != SUCCESS) {
|
||||||
MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
|
MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
return replace_graph_;
|
return replace_graph_;
|
||||||
}
|
}
|
||||||
|
@ -614,6 +725,13 @@ Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
|
Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
|
||||||
|
if (GetAttrs() != SUCCESS) {
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
if (manual_split_) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Manual split does not support to search strategy";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
is_auto_parallel_ = true;
|
is_auto_parallel_ = true;
|
||||||
Shape input0_split(inputs_shape_[0].size(), 1);
|
Shape input0_split(inputs_shape_[0].size(), 1);
|
||||||
Shape input1_split(inputs_shape_[1].size(), 1);
|
Shape input1_split(inputs_shape_[1].size(), 1);
|
||||||
|
@ -621,14 +739,14 @@ Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
|
||||||
|
|
||||||
std::vector<StrategyPtr> sp_vector;
|
std::vector<StrategyPtr> sp_vector;
|
||||||
if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
|
if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
|
||||||
MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed.";
|
MS_LOG(ERROR) << name_ << ": Generate strategies for independent inputs() failed.";
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
size_t success = 0;
|
size_t success = 0;
|
||||||
for (auto &sp : sp_vector) {
|
for (auto &sp : sp_vector) {
|
||||||
if (SetCostUnderStrategy(sp) == SUCCESS) {
|
if (SetCostUnderStrategy(sp) == SUCCESS) {
|
||||||
success++;
|
success++;
|
||||||
MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy";
|
MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy";
|
||||||
PrintStrategy(sp);
|
PrintStrategy(sp);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -636,6 +754,12 @@ Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<Strategys> GatherV2PInfo::GenerateBatchStrategies() {
|
std::shared_ptr<Strategys> GatherV2PInfo::GenerateBatchStrategies() {
|
||||||
|
if (GetAttrs() != SUCCESS) {
|
||||||
|
MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
|
||||||
|
}
|
||||||
|
if (manual_split_) {
|
||||||
|
MS_LOG(EXCEPTION) << name_ << ": Manual split does not support to generate batch strategy";
|
||||||
|
}
|
||||||
CheckGlobalDeviceManager();
|
CheckGlobalDeviceManager();
|
||||||
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
|
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
|
||||||
Dimensions param_strategy(inputs_shape_[0].size(), 1);
|
Dimensions param_strategy(inputs_shape_[0].size(), 1);
|
||||||
|
|
|
@ -59,7 +59,9 @@ class GatherV2PInfo : public OperatorInfo {
|
||||||
Status GetAttrs() override;
|
Status GetAttrs() override;
|
||||||
|
|
||||||
Status ComputeReplaceGraph(const CNodePtr &cnode);
|
Status ComputeReplaceGraph(const CNodePtr &cnode);
|
||||||
Status CheckManualSplit();
|
Status CheckManualSplit(const Strategys &strategy);
|
||||||
|
Status GetManualSplitAttr();
|
||||||
|
Status GetManualSplitWithoutOffsetAttr();
|
||||||
Status ComputeReplaceOp();
|
Status ComputeReplaceOp();
|
||||||
Status InferBias();
|
Status InferBias();
|
||||||
Status InferOffset();
|
Status InferOffset();
|
||||||
|
|
|
@ -48,6 +48,10 @@ class TensorLayout {
|
||||||
|
|
||||||
void set_field_size(int32_t field_size) { field_size_ = field_size; }
|
void set_field_size(int32_t field_size) { field_size_ = field_size; }
|
||||||
|
|
||||||
|
bool uniform_split() const { return uniform_split_; }
|
||||||
|
|
||||||
|
void set_uniform_split(bool flag) { uniform_split_ = flag; }
|
||||||
|
|
||||||
Arrangement device_arrangement() const { return device_arrangement_; }
|
Arrangement device_arrangement() const { return device_arrangement_; }
|
||||||
|
|
||||||
Map tensor_map() const { return tensor_map_; }
|
Map tensor_map() const { return tensor_map_; }
|
||||||
|
@ -104,6 +108,7 @@ class TensorLayout {
|
||||||
Arrangement tensor_shape_;
|
Arrangement tensor_shape_;
|
||||||
bool skip_redistribution_ = false;
|
bool skip_redistribution_ = false;
|
||||||
int32_t field_size_ = 0;
|
int32_t field_size_ = 0;
|
||||||
|
bool uniform_split_ = true;
|
||||||
};
|
};
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -229,10 +229,13 @@ def _load_tensor_by_layout(tensor, layout):
|
||||||
"""
|
"""
|
||||||
if not isinstance(layout, list):
|
if not isinstance(layout, list):
|
||||||
raise TypeError("The layout should be list! layout is {}".format(layout))
|
raise TypeError("The layout should be list! layout is {}".format(layout))
|
||||||
if len(layout) < 3:
|
if len(layout) < 5:
|
||||||
raise ValueError("The length of layout must be larger than 3! layout is {}".format(layout))
|
raise ValueError("The length of layout must be larger than 5! layout is {}".format(layout))
|
||||||
dev_mat = layout[0]
|
dev_mat = layout[0]
|
||||||
tensor_map = layout[1]
|
tensor_map = layout[1]
|
||||||
|
uniform_split = layout[4]
|
||||||
|
if uniform_split[0] == 0:
|
||||||
|
raise RuntimeError("The load tensor only support uniform split now")
|
||||||
if tensor.size() == 1:
|
if tensor.size() == 1:
|
||||||
return tensor
|
return tensor
|
||||||
return _load_tensor(tensor, dev_mat, tensor_map)
|
return _load_tensor(tensor, dev_mat, tensor_map)
|
||||||
|
|
|
@ -49,8 +49,8 @@ def test_get_parameter_layout():
|
||||||
net.set_auto_parallel()
|
net.set_auto_parallel()
|
||||||
exe = me._executor
|
exe = me._executor
|
||||||
exe.compile(net, x, phase='train', auto_parallel_mode=True)
|
exe.compile(net, x, phase='train', auto_parallel_mode=True)
|
||||||
x_layout = [[2, 4], [1, -1], [16, 32], [0]] # device_arrangement = [2, 4], tensor_map = [1, -1]
|
x_layout = [[2, 4], [1, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [1, -1]
|
||||||
weight_layout = [[2, 4], [0, -1], [16, 32], [0]] # device_arrangement = [2, 4], tensor_map = [0, -1]
|
weight_layout = [[2, 4], [0, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [0, -1]
|
||||||
expect_dict = {'x': x_layout, 'w1': weight_layout}
|
expect_dict = {'x': x_layout, 'w1': weight_layout}
|
||||||
# to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
|
# to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
|
||||||
assert net.parameter_layout_dict == expect_dict
|
assert net.parameter_layout_dict == expect_dict
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
import mindspore as ms
|
import mindspore as ms
|
||||||
from mindspore import context, Tensor, Parameter
|
from mindspore import context, Tensor, Parameter
|
||||||
from mindspore.common.api import _executor
|
from mindspore.common.api import _executor
|
||||||
|
@ -22,40 +23,170 @@ from mindspore.ops import operations as P
|
||||||
from mindspore.common.initializer import initializer
|
from mindspore.common.initializer import initializer
|
||||||
|
|
||||||
class Net(Cell):
|
class Net(Cell):
|
||||||
def __init__(self, strategy1=None, strategy2=None, strategy3=None):
|
def __init__(self,
|
||||||
|
strategy1=None,
|
||||||
|
strategy2=None,
|
||||||
|
strategy3=None,
|
||||||
|
axis=0,
|
||||||
|
init_flag=True,
|
||||||
|
split_tuple=(4, 4),
|
||||||
|
split_string="manual_split",
|
||||||
|
param_shape=(8, 8)):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.gatherv2 = P.GatherV2().set_strategy(strategy1)
|
self.gatherv2 = P.GatherV2().set_strategy(strategy1)
|
||||||
self.gatherv2.add_prim_attr("manual_split", ((1, 0), (7, 1)))
|
self.gatherv2.add_prim_attr(split_string, split_tuple)
|
||||||
self.mul = P.Mul().set_strategy(strategy2)
|
self.mul = P.Mul().set_strategy(strategy2)
|
||||||
self.reshape = P.Reshape()
|
self.reshape = P.Reshape()
|
||||||
self.matmul = P.MatMul().set_strategy(strategy3)
|
self.matmul = P.MatMul().set_strategy(strategy3)
|
||||||
self.matmul.add_prim_attr("forward_reduce_scatter", True)
|
self.matmul.add_prim_attr("forward_reduce_scatter", True)
|
||||||
self.param = Parameter(initializer("ones", (8, 64), ms.float32), name="gatherv2_param")
|
if init_flag:
|
||||||
self.mul_weight = Parameter(initializer("ones", (2, 4, 64), ms.float32), name="mul_weight")
|
self.param = Parameter(initializer("ones", param_shape, ms.float32), name="gatherv2_param")
|
||||||
self.matmul_weight = Parameter(initializer("ones", (256, 16), ms.float32), name="matmul_weight")
|
else:
|
||||||
|
self.param = Parameter(Tensor(np.ones(param_shape), dtype=ms.float32), name="gatherv2_param")
|
||||||
|
self.mul_weight = Parameter(initializer("ones", (8, 8, 8), ms.float32), name="mul_weight")
|
||||||
|
self.matmul_weight = Parameter(initializer("ones", (64, 16), ms.float32), name="matmul_weight")
|
||||||
|
self.axis = axis
|
||||||
|
|
||||||
def construct(self, x, b):
|
def construct(self, x, b):
|
||||||
out = self.gatherv2(self.param, x, 0)
|
out = self.gatherv2(self.param, x, self.axis)
|
||||||
out = self.mul(out, self.mul_weight)
|
out = self.mul(out, self.mul_weight)
|
||||||
out = self.reshape(out, (2, 256))
|
out = self.reshape(out, (8, 64))
|
||||||
out = self.matmul(out, self.matmul_weight)
|
out = self.matmul(out, self.matmul_weight)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
_x = Tensor(np.ones([2, 4]), dtype=ms.int32)
|
|
||||||
|
_x = Tensor(np.ones([8, 8]), dtype=ms.int32)
|
||||||
_b = Tensor(np.ones([64, 8]), dtype=ms.float32)
|
_b = Tensor(np.ones([64, 8]), dtype=ms.float32)
|
||||||
|
|
||||||
|
|
||||||
def compile_net(net):
|
def compile_net(net):
|
||||||
|
context.set_context(save_graphs=True)
|
||||||
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||||
train_net = TrainOneStepCell(net, optimizer)
|
train_net = TrainOneStepCell(net, optimizer)
|
||||||
train_net.set_auto_parallel()
|
train_net.set_auto_parallel()
|
||||||
_executor.compile(train_net, _x, _b)
|
_executor.compile(train_net, _x, _b, auto_parallel_mode=True)
|
||||||
context.reset_auto_parallel_context()
|
context.reset_auto_parallel_context()
|
||||||
|
|
||||||
def test_neg_data_parallel():
|
|
||||||
context.set_context(save_graphs=True)
|
def test_normal_split():
|
||||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
|
||||||
strategy1 = ((2, 1), (1, 2))
|
strategy1 = ((2, 1), (1, 2))
|
||||||
strategy2 = ((1, 2, 1), (1, 2, 1))
|
strategy2 = ((1, 2, 1), (1, 2, 1))
|
||||||
strategy3 = ((1, 2), (2, 1))
|
strategy3 = ((1, 2), (2, 1))
|
||||||
net = Net(strategy1, strategy2, strategy3)
|
net = Net(strategy1, strategy2, strategy3)
|
||||||
compile_net(net)
|
compile_net(net)
|
||||||
|
|
||||||
|
|
||||||
|
def test_normal_split2():
|
||||||
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
|
||||||
|
strategy1 = ((4, 1), (1, 4))
|
||||||
|
strategy2 = ((1, 4, 1), (1, 4, 1))
|
||||||
|
strategy3 = ((1, 4), (4, 1))
|
||||||
|
net = Net(strategy1, strategy2, strategy3, split_tuple=(10, 20, 30, 4), param_shape=(64, 8))
|
||||||
|
compile_net(net)
|
||||||
|
|
||||||
|
|
||||||
|
def test_normal_split3():
|
||||||
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=17)
|
||||||
|
strategy1 = ((4, 8), (1, 4))
|
||||||
|
strategy2 = ((1, 4, 8), (1, 4, 8))
|
||||||
|
strategy3 = ((1, 32), (32, 1))
|
||||||
|
net = Net(strategy1, strategy2, strategy3, split_tuple=(10, 20, 30, 4), param_shape=(64, 8))
|
||||||
|
compile_net(net)
|
||||||
|
|
||||||
|
|
||||||
|
def test_normal_split_with_offset():
|
||||||
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
|
||||||
|
strategy1 = ((2, 1), (1, 2))
|
||||||
|
strategy2 = ((1, 2, 1), (1, 2, 1))
|
||||||
|
strategy3 = ((1, 2), (2, 1))
|
||||||
|
net = Net(strategy1, strategy2, strategy3, split_string="manual_split_with_offset", split_tuple=((4, 0), (4, 4)))
|
||||||
|
compile_net(net)
|
||||||
|
|
||||||
|
|
||||||
|
def test_auto_parallel_error():
|
||||||
|
context.set_context(save_graphs=True)
|
||||||
|
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2, global_rank=0)
|
||||||
|
net = Net()
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
compile_net(net)
|
||||||
|
|
||||||
|
|
||||||
|
def test_axis_error():
|
||||||
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
|
||||||
|
strategy1 = ((2, 1), (1, 2))
|
||||||
|
strategy2 = ((1, 2, 1), (1, 2, 1))
|
||||||
|
strategy3 = ((1, 2), (2, 1))
|
||||||
|
net = Net(strategy1, strategy2, strategy3, axis=1)
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
compile_net(net)
|
||||||
|
|
||||||
|
|
||||||
|
def test_strategy_error():
|
||||||
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||||
|
strategy1 = ((4, 1), (8, 1))
|
||||||
|
strategy2 = ((1, 2, 1), (1, 2, 1))
|
||||||
|
strategy3 = ((1, 2), (2, 1))
|
||||||
|
net = Net(strategy1, strategy2, strategy3)
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
compile_net(net)
|
||||||
|
|
||||||
|
|
||||||
|
def test_strategy_error2():
|
||||||
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||||
|
strategy1 = ((4, 1), (1, 8))
|
||||||
|
strategy2 = ((1, 2, 1), (1, 2, 1))
|
||||||
|
strategy3 = ((1, 2), (2, 1))
|
||||||
|
net = Net(strategy1, strategy2, strategy3)
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
compile_net(net)
|
||||||
|
|
||||||
|
|
||||||
|
def test_strategy_error3():
|
||||||
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||||
|
strategy1 = ((2, 1), (1, 2))
|
||||||
|
strategy2 = ((1, 2, 1), (1, 2, 1))
|
||||||
|
strategy3 = ((1, 2), (2, 1))
|
||||||
|
net = Net(strategy1, strategy2, strategy3)
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
compile_net(net)
|
||||||
|
|
||||||
|
|
||||||
|
def test_strategy_error4():
|
||||||
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
|
||||||
|
strategy1 = ((2, 8), (1, 2))
|
||||||
|
strategy2 = ((1, 2, 1), (1, 2, 1))
|
||||||
|
strategy3 = ((1, 2), (2, 1))
|
||||||
|
net = Net(strategy1, strategy2, strategy3)
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
compile_net(net)
|
||||||
|
|
||||||
|
|
||||||
|
def test_strategy_error5():
|
||||||
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
|
||||||
|
strategy1 = ((4, 1), (1, 4))
|
||||||
|
strategy2 = ((1, 2, 1), (1, 2, 1))
|
||||||
|
strategy3 = ((1, 2), (2, 1))
|
||||||
|
net = Net(strategy1, strategy2, strategy3)
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
compile_net(net)
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_tuple_error():
|
||||||
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
|
||||||
|
strategy1 = ((2, 1), (1, 2))
|
||||||
|
strategy2 = ((1, 2, 1), (1, 2, 1))
|
||||||
|
strategy3 = ((1, 2), (2, 1))
|
||||||
|
net = Net(strategy1, strategy2, strategy3, split_tuple=((5, 0), (5, 5)))
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
compile_net(net)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parameter_use_tensor_error():
|
||||||
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
|
||||||
|
strategy1 = ((2, 1), (1, 2))
|
||||||
|
strategy2 = ((1, 2, 1), (1, 2, 1))
|
||||||
|
strategy3 = ((1, 2), (2, 1))
|
||||||
|
net = Net(strategy1, strategy2, strategy3, init_flag=False)
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
compile_net(net)
|
||||||
|
|
Loading…
Reference in New Issue