!66219 feature-2.3-layout-extend_merge_r2.3

Merge pull request !66219 from yao_yf/feature-2.3-layout-extend_merge_r2.3

commit 7740c81a4a

@@ -126,6 +126,7 @@ mindspore
.. mscnautosummary::
    :toctree: mindspore

    mindspore.Layout
    mindspore.shard

Just-in-time Compilation

@@ -0,0 +1,17 @@
mindspore.Layout
================

.. py:class:: mindspore.Layout(device_matrix, alias_name)

    Layout describes the detailed sharding information.

    .. note::
        - Only valid in semi-auto parallel or auto parallel mode.

    Parameters:
        - **device_matrix** (tuple) - Describes the shape of the device arrangement; each element is an int.
        - **alias_name** (tuple) - The alias of each axis of the device_matrix; each element is a string.

    .. py:method:: to_dict

        Convert the Layout to a dictionary.
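
A minimal usage sketch of the class documented above (the 2x2x2 device matrix and the axis aliases are hypothetical example values, not part of the patch):

```python
from mindspore import Layout  # documented above as mindspore.Layout

# Eight devices arranged as a 2x2x2 device matrix; each axis gets an alias.
layout = Layout(device_matrix=(2, 2, 2), alias_name=("dp", "sp", "mp"))

# Inspect the sharding description as a plain dictionary.
print(layout.to_dict())
```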

@@ -202,6 +202,7 @@ Parallel
    :nosignatures:
    :template: classtemplate.rst

    mindspore.Layout
    mindspore.shard

JIT

@@ -58,6 +58,9 @@ if(NOT BUILD_LITE)
        ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/strategy_checkpoint_info.cc)
    list(APPEND MSLIB_SRC_DEPEND
        ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc)
    list(APPEND MSLIB_SRC_DEPEND ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/group_manager.cc)
    list(APPEND MSLIB_SRC_DEPEND ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/device_manager.cc)
    list(APPEND MSLIB_SRC_DEPEND ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/device_matrix.cc)
    list(APPEND MSLIB_SRC_DEPEND ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/array.cc)
    list(APPEND MSLIB_SRC_DEPEND ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/map.cc)
    list(APPEND MSLIB_SRC_DEPEND ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.cc)

@@ -22,6 +22,7 @@
#include <numeric>
#include <utility>
#include <vector>
#include <set>

#include "frontend/parallel/status.h"
#include "utils/log_adapter.h"

@@ -101,6 +102,31 @@ Status DeviceMatrix::GetDevicesAlongDim(const uint64_t &dim, RankList *devices)
  return Status::FAILED;
}

Status DeviceMatrix::GetDevicesAlongMultiDim(const std::vector<int64_t> &dims, RankList *devices) {
  std::set<int64_t> repeated_rank_set;
  for (const auto &dim : dims) {
    if (dim != -1) {
      auto r_dim = LongToUlong(dim);
      if (repeated_rank_set.empty()) {
        DeviceMatrix dev_matrix(rank_, dev_list_, dev_shape_);
        RankList cur_dim_reduce_list;
        dev_matrix.GetDevicesAlongDim(r_dim, &cur_dim_reduce_list);
        repeated_rank_set.insert(cur_dim_reduce_list.begin(), cur_dim_reduce_list.end());
      } else {
        auto repeated_rank_set_cpy = repeated_rank_set;
        for (const auto &rank : repeated_rank_set_cpy) {
          DeviceMatrix dev_matrix(rank, dev_list_, dev_shape_);
          RankList dim_reduce_list;
          dev_matrix.GetDevicesAlongDim(r_dim, &dim_reduce_list);
          repeated_rank_set.insert(dim_reduce_list.begin(), dim_reduce_list.end());
        }
      }
    }
  }
  std::copy(repeated_rank_set.begin(), repeated_rank_set.end(), std::back_inserter(*devices));
  return SUCCESS;
}
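
To make the set-growing loop in GetDevicesAlongMultiDim above concrete, here is a self-contained Python sketch. It is a simplified re-implementation for illustration only (row-major rank ordering assumed, dimension indexing simplified), not the MindSpore C++ API:

```python
import math

def rank_to_coord(rank, dev_shape):
    """Row-major rank -> coordinate; the last axis varies fastest."""
    coord = []
    for d in reversed(dev_shape):
        coord.append(rank % d)
        rank //= d
    return list(reversed(coord))

def devices_along_dim(rank, dev_shape, dim):
    """All ranks whose coordinates agree with `rank` on every axis but `dim`."""
    base = rank_to_coord(rank, dev_shape)
    return [r for r in range(math.prod(dev_shape))
            if all(c == b for i, (c, b)
                   in enumerate(zip(rank_to_coord(r, dev_shape), base))
                   if i != dim)]

def devices_along_multi_dim(rank, dev_shape, dims):
    """Grow the rank set dim by dim, mirroring repeated_rank_set above."""
    ranks = {rank}
    for dim in dims:
        if dim == -1:
            continue
        ranks = {r2 for r in ranks for r2 in devices_along_dim(r, dev_shape, dim)}
    return sorted(ranks)

# Device matrix [2, 2, 2]: starting from rank 0 and reducing along the last
# two axes collects ranks 0..3.
print(devices_along_multi_dim(0, [2, 2, 2], [2, 1]))  # [0, 1, 2, 3]
```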

Shape ConvertRankToCoordinate(int64_t rank, const Shape &dev_shape) {
  Shape dev_coordinate;
  for (size_t i = 0; i < dev_shape.size(); ++i) {

@@ -38,6 +38,7 @@ class DeviceMatrix {
  Status CreateGroupList();
  Status GetDevicesByTensorMap(const Shape &tensor_map, RankList *rank_list);
  Status GetDevicesAlongDim(const uint64_t &dim, RankList *devices);
  Status GetDevicesAlongMultiDim(const std::vector<int64_t> &dims, RankList *devices);

 private:
  int64_t rank_ = -1;

@@ -276,14 +276,16 @@ py::dict GetParameterLayoutFromGraph(const FuncGraphPtr &graph) {
    } else {
      const auto &device_arrangement = tensor_layout->device_arrangement().array();
      const auto &tensor_map = tensor_layout->tensor_map().array();
      const auto &slice_shape = tensor_layout->slice_shape().array();
      const auto &slice_shape = tensor_layout->base_slice_shape().array();
      int64_t field_size = tensor_layout->get_field_size();
      bool uniform_split = tensor_layout->uniform_split();
      const std::string &opt_shard_group = tensor_layout->opt_shard_group();

      auto [is_pipeline_shared, is_send, peer_rank, sr_tag] = GetSharedParameterInfo(para);
      const auto &before_full_shape = tensor_layout->tensor_shape_before().array();
      const auto &after_slice_shape = tensor_layout->slice_shape().array();
      py::tuple layout = py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split,
                                        opt_shard_group, is_pipeline_shared, is_send, peer_rank, sr_tag);
                                        opt_shard_group, before_full_shape, after_slice_shape,
                                        is_pipeline_shared, is_send, peer_rank, sr_tag);
      for (auto &name : names) {
        dict[py::str(name)] = layout;
      }

@@ -305,13 +307,16 @@ py::dict GetParameterLayoutFromResource(const pipeline::ResourcePtr &resource) {
      const auto &slice_shape = layout->get_slice_shape();
      int64_t field_size = layout->get_field_size();
      bool uniform_split = layout->get_uniform_split();
      std::vector<int64_t> before_full_shape;
      std::vector<int64_t> after_slice_shape;
      const std::string &opt_shard_group = layout->get_opt_shard_group();
      bool is_pipeline_shared = layout->pipeline_shared();
      bool is_send = layout->is_send();
      int64_t peer_rank = layout->peer_rank();
      int64_t sr_tag = layout->sr_tag();
      py::tuple layout_tuple = py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split,
                                              opt_shard_group, is_pipeline_shared, is_send, peer_rank, sr_tag);
                                              opt_shard_group, before_full_shape, after_slice_shape,
                                              is_pipeline_shared, is_send, peer_rank, sr_tag);
      dict[py::str(name)] = layout_tuple;
    }
    return dict;

@@ -817,6 +817,37 @@ Status SortInfo::GetAttrs() {
  return SUCCESS;
}

Status GeLUInfo::CheckInputLayout() {
  if (inputs_tensor_info_.size() != kSizeOne) {
    MS_LOG(ERROR) << "The size of input_tensor_layout for gelu is " << inputs_tensor_info_.size() << " rather than 1.";
    return FAILED;
  }
  return SUCCESS;
}

Status GeLUInfo::CheckOutputLayout() {
  if (outputs_tensor_info_.size() != kSizeOne) {
    MS_LOG(ERROR) << "The size of output_tensor_layout for gelu is " << outputs_tensor_info_.size()
                  << " rather than 1.";
    return FAILED;
  }
  if (output_infer_tensor_layout_.tensor_shape_before().array().empty()) {
    MS_LOG(ERROR) << "Parameter of output tensor layout for gelu is not allowed to be set by users.";
    return FAILED;
  }
  MS_LOG(INFO) << "Using the output tensor layout inferred from the input tensor layout.";
  return SUCCESS;
}

Status GeLUInfo::InferOutputTensorInfo() {
  output_infer_tensor_layout_ = inputs_tensor_info_[kIndex0].tensor_layout();
  TensorInfo output_tensor_info(output_infer_tensor_layout_);
  outputs_tensor_info_.push_back(output_tensor_info);
  return SUCCESS;
}

Status GeLUInfo::InferForwardCommunicationByLayout() { return SUCCESS; }

REGISTER(ActivationInfo);
REGISTER(GeLUInfo);
REGISTER(FastGeLUInfo);

@@ -85,6 +85,15 @@ class GeLUInfo : public ActivationOther {
           const PrimitiveAttrs &attrs)
      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<GeLUCost>()) {}
  ~GeLUInfo() override = default;

 protected:
  Status InferForwardCommunicationByLayout() override;
  Status CheckInputLayout() override;
  Status CheckOutputLayout() override;
  Status InferOutputTensorInfo() override;

 private:
  TensorLayout output_infer_tensor_layout_;
};

class FastGeLUInfo : public ActivationOther {

@@ -510,6 +510,115 @@ Status MaskedFillInfo::InferMirrorOps() {
  return SUCCESS;
}

void ExpandSmallerShapes(const Shapes *bigger_size_shapes, Shapes *smaller_size_shapes) {
  size_t insert_num = bigger_size_shapes->size() - smaller_size_shapes->size();
  Shape map_none_shape(1, MAP_NONE);
  for (size_t num = 0; num < insert_num; ++num) {
    (void)smaller_size_shapes->insert(smaller_size_shapes->cbegin(), map_none_shape);
  }
}
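
What ExpandSmallerShapes does can be stated in one line; a simplified Python sketch (illustration only, not the C++ API):

```python
MAP_NONE = [-1]  # marks an unsharded dimension in an extended tensor map

def expand_smaller_shapes(bigger, smaller):
    """Left-pad the shorter tensor map with [-1] until both maps align."""
    return [MAP_NONE] * (len(bigger) - len(smaller)) + smaller

# A 1-D map aligned against a 3-D one:
print(expand_smaller_shapes([[2], [1], [0]], [[0]]))  # [[-1], [-1], [0]]
```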

Status AddInfo::CheckInputLayout() {
  // Check that all device matrices are the same
  if (inputs_tensor_info_.size() != kSizeTwo) {
    MS_LOG(ERROR) << "The size of input_tensor_layout for add is " << inputs_tensor_info_.size() << " rather than 2.";
    return FAILED;
  }
  auto in_layout0 = inputs_tensor_info_[kIndex0].tensor_layout();
  auto in_layout1 = inputs_tensor_info_[kIndex1].tensor_layout();
  if (in_layout0.device_arrangement_origin().array() != in_layout1.device_arrangement_origin().array()) {
    MS_LOG(ERROR) << "The device_matrix of input0 " << in_layout0.device_arrangement_origin().array()
                  << " does not equal the device_matrix of input1 " << in_layout1.device_arrangement_origin().array();
    return FAILED;
  }

  Shapes input_shapes = InferExpandShape();
  Shape input_shape_0 = input_shapes.at(0);
  Shape input_shape_1 = input_shapes.at(1);

  Shapes tensormap0 = in_layout0.tensor_map_before();
  Shapes tensormap1 = in_layout1.tensor_map_before();
  if (tensormap0.size() > tensormap1.size()) {
    (void)ExpandSmallerShapes(&tensormap0, &tensormap1);
  } else {
    (void)ExpandSmallerShapes(&tensormap1, &tensormap0);
  }

  for (size_t i = 0; i < input_shape_0.size(); ++i) {
    if (tensormap0[i] != tensormap1[i] && input_shape_0[i] != 1 && input_shape_1[i] != 1) {
      MS_LOG(ERROR) << name_ << " : Invalid strategy.";
      return FAILED;
    }
  }
  return SUCCESS;
}

TensorLayout AddInfo::InferOutputLayout() {
  auto in_layout0 = inputs_tensor_info_[kIndex0].tensor_layout();
  auto in_layout1 = inputs_tensor_info_[kIndex1].tensor_layout();
  Shapes tensormap0 = in_layout0.tensor_map_before();
  Shapes tensormap1 = in_layout1.tensor_map_before();

  Shapes input_shapes = InferExpandShape();
  Shape input_a_shape = input_shapes.at(0);
  Shape input_b_shape = input_shapes.at(1);

  for (size_t i = 0; i < input_a_shape.size(); ++i) {
    input_a_shape[i] = (input_a_shape[i] == 1) ? input_b_shape[i] : input_a_shape[i];
  }

  Shapes output_tensormap;
  Shape map_none_shape(1, MAP_NONE);
  size_t len_diff = 0;
  if (tensormap0.size() > tensormap1.size()) {
    output_tensormap = tensormap0;
    len_diff = tensormap0.size() - tensormap1.size();
    for (size_t i = 0; i < tensormap1.size(); ++i) {
      output_tensormap[i + len_diff] =
        tensormap0[i + len_diff] == map_none_shape ? tensormap1[i] : tensormap0[i + len_diff];
    }
  } else {
    output_tensormap = tensormap1;
    len_diff = tensormap1.size() - tensormap0.size();
    for (size_t i = 0; i < tensormap0.size(); ++i) {
      output_tensormap[i + len_diff] =
        tensormap1[i + len_diff] == map_none_shape ? tensormap0[i] : tensormap1[i + len_diff];
    }
  }

  TensorLayout output_tensor_layout;
  output_tensor_layout.InitFromExtendVector(in_layout0.device_arrangement_origin().array(), output_tensormap,
                                            input_a_shape);
  return output_tensor_layout;
}
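
The branchy merge above reduces to one rule: right-align the two tensor maps and, wherever the longer map leaves a dim unsharded, take the mapping from the broadcast input. A simplified Python sketch (illustration only; CheckInputLayout has already rejected conflicting shards):

```python
MAP_NONE = [-1]

def merge_tensor_maps(tm_big, tm_small):
    out = list(tm_big)
    diff = len(tm_big) - len(tm_small)
    for i, m in enumerate(tm_small):
        if out[i + diff] == MAP_NONE:
            out[i + diff] = m
    return out

# input0 (3-D) leaves its middle dim unsharded; input1 (2-D) shards it:
print(merge_tensor_maps([[2], [-1], [0]], [[1], [0]]))  # [[2], [1], [0]]
```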

Status AddInfo::InferOutputTensorInfo() {
  output_infer_tensor_layout_ = InferOutputLayout();
  if (output_infer_tensor_layout_.tensor_shape_before().array() != outputs_shape_[kIndex0]) {
    MS_LOG(ERROR) << "The inferred output shape " << output_infer_tensor_layout_.tensor_shape_before().array()
                  << " does not match the output shape " << outputs_shape_[kIndex0];
    return FAILED;
  }
  TensorInfo output_tensor_info(output_infer_tensor_layout_);
  outputs_tensor_info_.push_back(output_tensor_info);
  return SUCCESS;
}

Status AddInfo::CheckOutputLayout() {
  if (outputs_tensor_info_.size() != kSizeOne) {
    MS_LOG(ERROR) << "The size of output_tensor_layout for add is " << outputs_tensor_info_.size()
                  << " rather than 1.";
    return FAILED;
  }

  if (output_infer_tensor_layout_.tensor_shape_before().array().empty()) {
    MS_LOG(ERROR) << "Parameter of output tensor layout for add is not allowed to be set by users.";
    return FAILED;
  }
  MS_LOG(INFO) << "Using the output tensor layout inferred from the input tensor layout.";
  return SUCCESS;
}

REGISTER(SubInfo);
REGISTER(AddInfo);
REGISTER(MulInfo);

@@ -63,6 +63,16 @@ class AddInfo : public ArithmeticBase {
  AddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<TensorAddCost>()) {}
  ~AddInfo() override = default;

 protected:
  Status CheckInputLayout() override;
  Status CheckOutputLayout() override;
  Status InferOutputTensorInfo() override;
  Status InferForwardCommunicationByLayout() override { return SUCCESS; }

 private:
  TensorLayout InferOutputLayout();
  TensorLayout output_infer_tensor_layout_;
};

class MulInfo : public ArithmeticBase {

@@ -119,6 +119,62 @@ std::vector<StrategyPtr> BiasAddInfo::GenerateOpStrategies(int64_t stage_id) {
  return sp_vector;
}

Status BiasAddInfo::CheckInputLayout() {
  // Check that all device matrices are the same
  if (inputs_tensor_info_.size() != kSizeTwo) {
    MS_LOG(ERROR) << "The size of input_tensor_layout for bias_add is " << inputs_tensor_info_.size()
                  << " rather than 2.";
    return FAILED;
  }
  auto in_layout0 = inputs_tensor_info_[kIndex0].tensor_layout();
  auto in_layout1 = inputs_tensor_info_[kIndex1].tensor_layout();
  if (in_layout0.device_arrangement_origin().array() != in_layout1.device_arrangement_origin().array()) {
    MS_LOG(ERROR) << "The device_matrix of input0 " << in_layout0.device_arrangement_origin().array()
                  << " does not equal the device_matrix of input1 " << in_layout1.device_arrangement_origin().array();
    return FAILED;
  }

  if (in_layout0.tensor_map_before().back() != in_layout1.tensor_map_before()[0]) {
    MS_LOG(ERROR) << "The shard size of bias_add is not equal for the last dim of input0 and input1";
    return FAILED;
  }
  return SUCCESS;
}
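
The compatibility rule checked above, as a simplified Python sketch (illustration only): the 1-D bias must be sharded exactly like the input's last dimension.

```python
def bias_maps_compatible(input_map, bias_map):
    return input_map[-1] == bias_map[0]

print(bias_maps_compatible([[1], [0]], [[0]]))   # True
print(bias_maps_compatible([[1], [0]], [[-1]]))  # False
```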

Status BiasAddInfo::InferOutputTensorInfo() {
  auto in_layout0 = inputs_tensor_info_[kIndex0].tensor_layout();
  // The output layout should be the same as the layout of input 0
  if (in_layout0.tensor_shape_before().array() != outputs_shape_[kIndex0]) {
    MS_LOG(ERROR) << "The inferred output shape " << in_layout0.tensor_shape_before().array()
                  << " does not match the output shape " << outputs_shape_[kIndex0];
    return FAILED;
  }

  TensorLayout output_tensor_layout;
  output_tensor_layout.InitFromExtendVector(in_layout0.device_arrangement_origin().array(),
                                            in_layout0.tensor_map_before(), in_layout0.tensor_shape_before().array());

  TensorInfo output_tensor_info(output_tensor_layout);
  outputs_tensor_info_.push_back(output_tensor_info);
  return SUCCESS;
}

Status BiasAddInfo::CheckOutputLayout() {
  if (outputs_tensor_info_.size() != kSizeOne) {
    MS_LOG(ERROR) << "The size of output_tensor_layout for bias_add is " << outputs_tensor_info_.size()
                  << " rather than 1.";
    return FAILED;
  }
  auto out_layout = outputs_tensor_info_[kIndex0].tensor_layout();
  auto in_layout0 = inputs_tensor_info_[kIndex0].tensor_layout();
  if (out_layout.tensor_map_before() != in_layout0.tensor_map_before()) {
    MS_LOG(ERROR) << "The output layout of bias_add does not match the layout of the first input";
    return FAILED;
  }
  MS_LOG(INFO) << "Using the output tensor layout inferred from the input tensor layout.";
  return SUCCESS;
}

REGISTER(BiasAddInfo);
}  // namespace parallel
}  // namespace mindspore

@@ -48,6 +48,10 @@ class BiasAddInfo : public OperatorInfo {
  Status InferForwardCommunication() override { return SUCCESS; }
  Status InferDevMatrixShape() override;
  Status InferTensorMap() override;
  Status CheckInputLayout() override;
  Status CheckOutputLayout() override;
  Status InferOutputTensorInfo() override;
  Status InferForwardCommunicationByLayout() override { return SUCCESS; }
};
}  // namespace parallel
}  // namespace mindspore

@@ -1316,7 +1316,9 @@ Status GatherInfo::ComputeReplaceOp() {
  return SUCCESS;
}

Status GatherInfo::Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) {
Status GatherInfo::Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy,
                        const std::vector<std::shared_ptr<TensorLayout>> &in_tensor_layouts,
                        const std::vector<std::shared_ptr<TensorLayout>> &out_tensor_layouts) {
  if (InitWithAutoRepeatCalc(in_strategy, out_strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init failed.";
    return FAILED;

@@ -307,7 +307,9 @@ class GatherInfo : public OperatorInfo {
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherCost>()),
        replace_op_name_(replace_op_name) {}
  ~GatherInfo() override = default;
  Status Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) override;
  Status Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy,
              const std::vector<std::shared_ptr<TensorLayout>> &in_tensor_layouts = {},
              const std::vector<std::shared_ptr<TensorLayout>> &out_tensor_layouts = {}) override;
  Status InitForCostModel(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) override;

  std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;

@@ -266,6 +266,131 @@ Status LayerNormInfo::InitShapes() {
  return SUCCESS;
}

Status LayerNormInfo::CheckInputLayout() {
  // Check that all device matrices are the same
  if (inputs_tensor_info_.size() != kSizeThree) {
    MS_LOG(ERROR) << "The size of input_tensor_layout for layernorm is " << inputs_tensor_info_.size()
                  << " rather than 3.";
    return FAILED;
  }
  auto in_layout = inputs_tensor_info_[kIndex0].tensor_layout();
  auto gamma_layout = inputs_tensor_info_[kIndex1].tensor_layout();
  auto beta_layout = inputs_tensor_info_[kIndex2].tensor_layout();

  // Check the input layout:
  // dims in [begin_norm_axis_, -1] must not be sharded
  const std::vector<int64_t> np_split_map = {-1};
  for (size_t i = begin_norm_axis_; i < in_layout.tensor_map_before().size(); ++i) {
    if (in_layout.tensor_map_before()[i] != np_split_map) {
      MS_LOG(ERROR) << "LayerNorm: invalid input layout " << in_layout.tensor_map_before();
      return FAILED;
    }
  }

  // Check the gamma and beta layouts
  if (gamma_layout.tensor_map_before() != beta_layout.tensor_map_before()) {
    MS_LOG(ERROR) << "The tensor map of gamma " << gamma_layout.tensor_map_before()
                  << " does not equal the tensor map of beta " << beta_layout.tensor_map_before();
    return FAILED;
  }

  size_t gamma_diff = in_layout.tensor_map_before().size() - gamma_layout.tensor_map_before().size();
  for (size_t j = 0; j < gamma_layout.tensor_map_before().size(); ++j) {
    if (gamma_layout.tensor_map_before()[j] != in_layout.tensor_map_before()[gamma_diff + j]) {
      MS_LOG(ERROR) << "LayerNorm: invalid gamma layout " << gamma_layout.tensor_map_before();
      return FAILED;
    }
  }

  size_t beta_diff = in_layout.tensor_map_before().size() - beta_layout.tensor_map_before().size();
  for (size_t j = 0; j < beta_layout.tensor_map_before().size(); ++j) {
    if (beta_layout.tensor_map_before()[j] != in_layout.tensor_map_before()[beta_diff + j]) {
      MS_LOG(ERROR) << "LayerNorm: invalid beta layout " << beta_layout.tensor_map_before();
      return FAILED;
    }
  }

  return SUCCESS;
}

Status LayerNormInfo::CheckOutputLayout() {
  // Check the number of output layouts; they must be inferred rather than user-set
  if (outputs_tensor_info_.size() != kSizeThree) {
    MS_LOG(ERROR) << "The size of output_tensor_layout for layernorm is " << outputs_tensor_info_.size()
                  << " rather than 3.";
    return FAILED;
  }
  if (output_infer_tensor_layout_.tensor_shape_before().array().empty()) {
    MS_LOG(ERROR) << "Parameter of output tensor layout for layernorm is not allowed to be set by users.";
    return FAILED;
  }
  MS_LOG(INFO) << "Using the output tensor layout inferred from the input tensor layout.";
  return SUCCESS;
}

Status LayerNormInfo::InferOutputLayout() {
  auto input_layout = inputs_tensor_info_[kIndex0].tensor_layout();

  TensorLayout output_tensor_layout;
  TensorLayout mean_tensor_layout;
  TensorLayout var_tensor_layout;
  output_tensor_layout = input_layout;
  mean_tensor_layout = output_tensor_layout;
  std::vector<Shape> mean_extended_tensor_map;
  Shape mean_tensor_shape;

  for (size_t i = 0; i < mean_tensor_layout.tensor_shape_before().array().size(); ++i) {
    auto map_dim = input_layout.tensor_map_before()[i];
    auto shp_dim = input_layout.tensor_shape_before().array()[i];
    mean_extended_tensor_map.push_back(map_dim);
    if (i < begin_norm_axis_) {
      mean_tensor_shape.push_back(shp_dim);
    } else {
      mean_tensor_shape.push_back(1);
    }
  }
  mean_tensor_layout.InitFromExtendVector(mean_tensor_layout.device_arrangement_origin().array(),
                                          mean_extended_tensor_map, mean_tensor_shape);
  var_tensor_layout = mean_tensor_layout;

  output_infer_tensor_layout_ = output_tensor_layout;
  mean_infer_tensor_layout_ = mean_tensor_layout;
  var_infer_tensor_layout_ = var_tensor_layout;

  return SUCCESS;
}
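
The mean/variance layout built above keeps the input's tensor map and collapses every normalized dim to size 1. A simplified Python sketch of the shape rule (illustration only):

```python
def layernorm_stat_shape(input_shape, begin_norm_axis):
    return [s if i < begin_norm_axis else 1 for i, s in enumerate(input_shape)]

print(layernorm_stat_shape([8, 128, 1024], begin_norm_axis=2))  # [8, 128, 1]
```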

Status LayerNormInfo::InferOutputTensorInfo() {
  InferOutputLayout();
  if (output_infer_tensor_layout_.tensor_shape_before().array() != outputs_shape_[kIndex0]) {
    MS_LOG(ERROR) << "The inferred output shape " << output_infer_tensor_layout_.tensor_shape_before().array()
                  << " does not match the output shape " << outputs_shape_[kIndex0];
    return FAILED;
  }
  if (mean_infer_tensor_layout_.tensor_shape_before().array() != outputs_shape_[kIndex1]) {
    MS_LOG(ERROR) << "The inferred output mean shape " << mean_infer_tensor_layout_.tensor_shape_before().array()
                  << " does not match the output shape " << outputs_shape_[kIndex1];
    return FAILED;
  }
  if (var_infer_tensor_layout_.tensor_shape_before().array() != outputs_shape_[kIndex2]) {
    MS_LOG(ERROR) << "The inferred output var shape " << var_infer_tensor_layout_.tensor_shape_before().array()
                  << " does not match the output shape " << outputs_shape_[kIndex2];
    return FAILED;
  }
  TensorInfo output_tensor_info(output_infer_tensor_layout_);
  TensorInfo mean_tensor_info(mean_infer_tensor_layout_);
  TensorInfo var_tensor_info(var_infer_tensor_layout_);
  outputs_tensor_info_.push_back(output_tensor_info);
  outputs_tensor_info_.push_back(mean_tensor_info);
  outputs_tensor_info_.push_back(var_tensor_info);
  return SUCCESS;
}

Status LayerNormInfo::InferForwardCommunicationByLayout() {
  // LayerNorm needs no forward communication
  return SUCCESS;
}

REGISTER(LayerNormInfo);
}  // namespace parallel
}  // namespace mindspore

@@ -58,12 +58,20 @@ class LayerNormInfo : public OperatorInfo {
  Status GenerateGammaAndBetaStrategies(const std::vector<StrategyPtr> &sp_vector);
  Status InitShapes();
  Status InferMirrorOps() override;
  Status InferOutputTensorInfo() override;
  Status InferForwardCommunicationByLayout() override;
  Status CheckInputLayout() override;
  Status CheckOutputLayout() override;

 private:
  size_t begin_norm_axis_;
  Shape input_shape_;
  Shape gamma_shape_;
  Shape beta_shape_;
  Status InferOutputLayout();
  TensorLayout output_infer_tensor_layout_;
  TensorLayout mean_infer_tensor_layout_;
  TensorLayout var_infer_tensor_layout_;
};
}  // namespace parallel
}  // namespace mindspore

@@ -422,6 +422,186 @@ Status MatMul::InferOutputTensorMap() {
  return SUCCESS;
}

Status MatMul::CheckInputLayout() {
  // Check that all device matrices are the same
  if (inputs_tensor_info_.size() != kSizeTwo) {
    MS_LOG(ERROR) << "The size of input_tensor_layout for matmul is " << inputs_tensor_info_.size()
                  << " rather than 2.";
    return FAILED;
  }
  auto in_layout0 = inputs_tensor_info_[kIndex0].tensor_layout();
  auto in_layout1 = inputs_tensor_info_[kIndex1].tensor_layout();
  if (in_layout0.device_arrangement_origin().array() != in_layout1.device_arrangement_origin().array()) {
    MS_LOG(ERROR) << "The device_matrix of input0 " << in_layout0.device_arrangement_origin().array()
                  << " does not equal the device_matrix of input1 " << in_layout1.device_arrangement_origin().array();
    return FAILED;
  }

  size_t axis0 = in_layout0.tensor_shape_before().array().size() - 1;
  if (transpose_a_) {
    axis0--;
  }
  size_t axis1 = in_layout0.tensor_shape_before().array().size() - 2;
  if (transpose_b_) {
    axis1++;
  }
  if (in_layout0.tensor_map_before()[axis0] != in_layout1.tensor_map_before()[axis1]) {
    MS_LOG(ERROR) << "The shard size of the reduced dimension is not equal for input0 and input1";
    return FAILED;
  }
  return SUCCESS;
}

Status MatMul::CheckOutputLayout() {
  if (outputs_tensor_info_.size() != kSizeOne) {
    MS_LOG(ERROR) << "The size of output_tensor_layout for matmul is " << outputs_tensor_info_.size()
                  << " rather than 1.";
    return FAILED;
  }
  auto out_layout = outputs_tensor_info_[kIndex0].tensor_layout();
  if (!output_infer_tensor_layout_.tensor_shape_before().array().empty()) {
    MS_LOG(INFO) << "Using the output tensor layout inferred from the input tensor layout.";
    return SUCCESS;
  }
  output_infer_tensor_layout_ = InferOutputLayout();
  if (output_infer_tensor_layout_ == out_layout) {
    MS_LOG(INFO)
      << "The output tensor layout inferred from the input tensor layout is the same as the user-configured one.";
    return SUCCESS;
  }

  auto input_layout0 = inputs_tensor_info_[kIndex0].tensor_layout();
  int64_t axis0 = input_layout0.tensor_shape_before().array().size() - 1;
  if (transpose_a_) {
    axis0 -= 1;
  }
  auto output_extended_tensor_map = output_infer_tensor_layout_.tensor_map_before();
  auto axis_map = input_layout0.tensor_map_before()[axis0];
  output_extended_tensor_map[0].insert(output_extended_tensor_map[0].end(), axis_map.begin(), axis_map.end());
  TensorLayout reduce_scatter_out_layout;
  reduce_scatter_out_layout.InitFromExtendVector(output_infer_tensor_layout_.device_arrangement_origin().array(),
                                                 output_extended_tensor_map,
                                                 output_infer_tensor_layout_.tensor_shape_before().array());
  if (reduce_scatter_out_layout != out_layout) {
    MS_LOG(ERROR) << "The user-configured output layout { device_matrix:"
                  << out_layout.device_arrangement_origin().array() << ", tensor_map:" << out_layout.tensor_map_before()
                  << ", tensor_shape:" << out_layout.tensor_shape_before().array()
                  << " } does not match the inferred output layout { device_matrix:"
                  << output_infer_tensor_layout_.device_arrangement_origin().array()
                  << ", tensor_map:" << output_infer_tensor_layout_.tensor_map_before()
                  << ", tensor_shape:" << output_infer_tensor_layout_.tensor_shape_before().array()
                  << " } (using all_reduce) or { device_matrix:"
                  << reduce_scatter_out_layout.device_arrangement_origin().array()
                  << ", tensor_map:" << reduce_scatter_out_layout.tensor_map_before()
                  << ", tensor_shape:" << reduce_scatter_out_layout.tensor_shape_before().array()
                  << " } (using reduce_scatter)";
    return FAILED;
  }
  forward_reduce_scatter_ = true;
  return SUCCESS;
}

TensorLayout MatMul::InferOutputLayout() {
  auto input_layout0 = inputs_tensor_info_[kIndex0].tensor_layout();
  auto input_layout1 = inputs_tensor_info_[kIndex1].tensor_layout();
  size_t axis0 = input_layout0.tensor_shape_before().array().size() - 1;
  if (transpose_a_) {
    axis0 -= 1;
  }
  std::vector<Shape> output_extended_tensor_map;
  Shape output_tensor_shape;

  for (size_t i = 0; i < input_layout0.tensor_shape_before().array().size(); ++i) {
    auto map_dim = input_layout0.tensor_map_before()[i];
    auto shp_dim = input_layout0.tensor_shape_before().array()[i];
    if (i != axis0) {
      output_extended_tensor_map.push_back(map_dim);
      output_tensor_shape.push_back(shp_dim);
    }
  }

  if (!transpose_b_) {
    output_extended_tensor_map.push_back(input_layout1.tensor_map_before()[inputs_shape_[kIndex1].size() - 1]);
    output_tensor_shape.push_back(input_layout1.tensor_shape_before().GetDimByIdx(inputs_shape_[kIndex1].size() - 1));
  } else {
    output_extended_tensor_map.push_back(input_layout1.tensor_map_before()[inputs_shape_[kIndex1].size() - 2]);
    output_tensor_shape.push_back(input_layout1.tensor_shape_before().GetDimByIdx(inputs_shape_[kIndex1].size() - 2));
  }

  TensorLayout output_tensor_layout;
  output_tensor_layout.InitFromExtendVector(input_layout0.device_arrangement_origin().array(),
                                            output_extended_tensor_map, output_tensor_shape);
  return output_tensor_layout;
}
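
InferOutputLayout above composes the output layout from the two inputs: drop input0's contracted dimension, then append input1's N dimension. A simplified Python sketch (illustration only; the example tensor maps and shapes are hypothetical):

```python
def infer_matmul_output_map(map_a, shape_a, map_b, shape_b,
                            transpose_a=False, transpose_b=False):
    k_axis = len(shape_a) - (2 if transpose_a else 1)   # contracted dim of A
    out_map = [m for i, m in enumerate(map_a) if i != k_axis]
    out_shape = [s for i, s in enumerate(shape_a) if i != k_axis]
    n_axis = len(shape_b) - (2 if transpose_b else 1)   # N dim of B
    out_map.append(map_b[n_axis])
    out_shape.append(shape_b[n_axis])
    return out_map, out_shape

# A: [1024, 512] with map [[1], [0]]; B: [512, 2048] with map [[0], [-1]]
print(infer_matmul_output_map([[1], [0]], [1024, 512],
                              [[0], [-1]], [512, 2048]))
# -> ([[1], [-1]], [1024, 2048])
```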

Status MatMul::InferOutputTensorInfo() {
  output_infer_tensor_layout_ = InferOutputLayout();
  if (output_infer_tensor_layout_.tensor_shape_before().array() != outputs_shape_[kIndex0]) {
    MS_LOG(ERROR) << "The inferred output shape " << output_infer_tensor_layout_.tensor_shape_before().array()
                  << " does not match the output shape " << outputs_shape_[kIndex0];
    return FAILED;
  }
  TensorInfo output_tensor_info(output_infer_tensor_layout_);
  outputs_tensor_info_.push_back(output_tensor_info);
  return SUCCESS;
}

Status MatMul::InferForwardCommunicationByLayout() {
  forward_op_.clear();
  auto input_layout0 = inputs_tensor_info_[kIndex0].tensor_layout();
  size_t axis0 = input_layout0.tensor_shape_before().array().size() - 1;
  if (transpose_a_) {
    axis0 -= 1;
  }
  auto axis_tensor_map = input_layout0.tensor_map_before()[axis0];
  int64_t axis_shard = 1;
  std::vector<int64_t> r_dim_vector;
  for (const auto &dim : axis_tensor_map) {
    if (dim == -1) {
      continue;
    }
    int64_t divisor = input_layout0.device_arrangement_origin().GetDimByReverseIdx(LongToUlong(dim));
    axis_shard *= divisor;
    auto r_dim = SizeToLong(input_layout0.device_arrangement_origin().array().size() - 1) - dim;
    r_dim_vector.push_back(r_dim);
  }
  // If the relevant dimension is not split, no all-reduce is required.
  if (axis_shard == MIN_SLICE_NUM) {
    MS_LOG(INFO) << name_ << ": Forward communication is not required.";
    return SUCCESS;
  }
  RankList repeated_rank_list;
  auto device_matrix = DeviceMatrix(g_device_manager->global_rank(), g_device_manager->GetDeviceListInThisStage(),
                                    input_layout0.device_arrangement_origin().array());
  if (device_matrix.GetDevicesAlongMultiDim(r_dim_vector, &repeated_rank_list) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Inferring forward communication by multiple axes failed.";
    return FAILED;
  }
  if (repeated_rank_list.size() == 1) {
    MS_LOG(INFO) << name_ << ": Forward communication is not required.";
    return SUCCESS;
  }

  Group forward_group;
  if (g_device_manager->CreateGroup(repeated_rank_list, &forward_group) != SUCCESS) {
    MS_LOG(ERROR) << name_
                  << ": Create communication group by tensor_map failed, the rank_list is: " << repeated_rank_list
                  << ", the full_name of node is: " << cnode_->fullname_with_scope();
    return FAILED;
  }

  Operator op;
  if (forward_reduce_scatter_) {
    op = CreateReduceScatterOp(REDUCE_OP_SUM, forward_group.name());
  } else {
    op = CreateAllReduceOp(REDUCE_OP_SUM, forward_group.name());
  }
  forward_op_.push_back(op);
  MS_LOG(INFO) << name_ << ": The group name of the forward communication is " << forward_group.name();
  return SUCCESS;
}
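
The shard size of the contracted axis decides whether any forward communication is inserted at all. A simplified Python sketch (illustration only) of the accumulation in the loop above; tensor-map values index the device matrix from the right:

```python
def axis_shard_size(dev_matrix, axis_map):
    shard = 1
    for dim in axis_map:
        if dim != -1:
            shard *= dev_matrix[len(dev_matrix) - 1 - dim]
    return shard

print(axis_shard_size([2, 2, 4], [0]))   # 4 -> AllReduce over groups of 4
print(axis_shard_size([2, 2, 4], [-1]))  # 1 -> no forward communication
```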

Status MatMul::CheckLayoutConfig() {
  if (CheckInputStrategy(strategy_from_layout_[0], strategy_from_layout_[1]) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": invalid layout config, the dev matrix is " << dev_matrix_shape_

@@ -73,11 +73,17 @@ class MatMul : public MatMulBase {
  Status CheckStrategy(const StrategyPtr &strategy) override;
  Status CheckOutputStrategy(const StrategyPtr &out_strategy) override;
  Status InferOutputTensorMap() override;
  Status InferOutputTensorInfo() override;
  Status InferForwardCommunicationByLayout() override;
  Status CheckLayoutConfig() override;
  Status CheckInputLayout() override;
  Status CheckOutputLayout() override;

 private:
  void CheckPCLMatMul(const Shape &mat_a_strategy, const Shape &mat_b_strategy);
  Status CheckInputStrategy(const Shape &mat_a_strategy, const Shape &mat_b_strategy);
  TensorLayout InferOutputLayout();
  TensorLayout output_infer_tensor_layout_;
};

class MatMulInfo : public MatMul {

@@ -404,6 +404,53 @@ Status OperatorInfo::InferMirrorOps() {
  return SUCCESS;
}

Status OperatorInfo::InferMirrorOpsByLayout() {
  mirror_ops_.clear();
  if (inputs_shape_.empty()) {
    MS_LOG(INFO) << name_ << ": The inputs size is empty";
    return SUCCESS;
  }

  bool group_is_empty = true;
  for (size_t i = 0; i < inputs_tensor_info_.size(); ++i) {
    auto input_tensor_layout = inputs_tensor_info_[i].tensor_layout();
    auto repeated_rank_list = input_tensor_layout.InferRepeatedGroup();

    OperatorVector mirror_op;
    if (repeated_rank_list.size() == 1) {
      MS_LOG(INFO) << name_ << ": The mirror group is empty, the input index is " << i;
      mirror_ops_.push_back(mirror_op);
      continue;
    }
    if (is_auto_parallel_) {
      if (g_device_manager->CheckDeviceList(repeated_rank_list) != SUCCESS) {
        MS_LOG(INFO) << name_ << ": Try to create communication group : " << repeated_rank_list
                     << " failed in auto parallel mode, "
                        "this error can be ignored in parallel strategies searching step";
        return FAILED;
      }
      return SUCCESS;
    }

    Group mirror_group;
    if (g_device_manager->CreateGroup(repeated_rank_list, &mirror_group) != SUCCESS) {
      MS_LOG(ERROR) << name_
                    << ": Create communication group by tensor_map failed, the rank_list is: " << repeated_rank_list
                    << ", the full_name of node is: " << cnode_->fullname_with_scope();
      return FAILED;
    }
    group_is_empty = false;
    mirror_op = CreateMirrorOps(mirror_group.name(), mirror_group.GetDevNum());
    mirror_ops_.push_back(mirror_op);
  }

  if (group_is_empty) {
    mirror_ops_.clear();
    MS_LOG(INFO) << name_ << ": No need to insert mirror ops";
  }
  return SUCCESS;
}

Status OperatorInfo::InferTensorInfo() {
  if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
    MS_LOG(ERROR) << name_ << ": Invalid args";

@@ -683,6 +730,8 @@ Operator CreateMicroStepAllGatherOp(const std::string &group) {
Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout) {
  Shape tensor_map = tensor_layout.tensor_map().array();
  Shape dev_matrix_shape = tensor_layout.device_arrangement().array();
  Shape slice_shape = tensor_layout.base_slice_shape().array();
  Shape full_shape = tensor_layout.tensor_shape().array();
  OperatorName operator_name = GET_TENSOR_SLICE;

  OperatorAttrs attrs;

@@ -690,7 +739,11 @@ Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout) {
  Param dev_mat_param = std::make_pair(std::make_pair(DEV_MAT, dev_mat_value), 2);
  ValuePtr tensor_map_value = MakeValue(tensor_map);
  Param tensor_map_param = std::make_pair(std::make_pair(TENSOR_MAP, tensor_map_value), 3);
  OperatorParams params = {dev_mat_param, tensor_map_param};
  ValuePtr slice_shape_value = MakeValue(slice_shape);
  Param slice_shape_param = std::make_pair(std::make_pair(SLICE_SHAPE, slice_shape_value), 4);
  ValuePtr full_shape_value = MakeValue(full_shape);
  Param full_shape_param = std::make_pair(std::make_pair(FULL_SHAPE, full_shape_value), 5);
  OperatorParams params = {dev_mat_param, tensor_map_param, slice_shape_param, full_shape_param};
  OperatorArgs operator_arg = std::make_pair(attrs, params);

  Operator op = std::make_pair(operator_name, operator_arg);

@@ -784,7 +837,7 @@ Status OperatorInfo::CreateGroupForOptShard(TensorLayout *tensor_layout, std::ve
  }
  CheckGlobalDeviceManager();
  int64_t rank = g_device_manager->global_rank();
  DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
  DeviceMatrix dev_matrix(rank, stage_device_list_, tensor_layout->device_arrangement_origin().array());
  RankList group_devices;
  Shape tensor_map = tensor_layout->origin_tensor_map().array();
  if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {

@@ -927,7 +980,13 @@ Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy) {
  return slice_shape;
}

Status OperatorInfo::Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) {
Status OperatorInfo::Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy,
                          const std::vector<std::shared_ptr<TensorLayout>> &in_tensor_layouts,
                          const std::vector<std::shared_ptr<TensorLayout>> &out_tensor_layouts) {
  if (!in_tensor_layouts.empty()) {
    return InitWithTensorLayout(in_tensor_layouts, out_tensor_layouts);
  }

  if (InitWithAutoRepeatCalc(in_strategy, out_strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << " : Init failed.";
    return FAILED;

@@ -1057,6 +1116,61 @@ Status OperatorInfo::InitWithAutoRepeatCalc(const StrategyPtr &in_strategy, cons
  return SUCCESS;
}

Status OperatorInfo::InitWithTensorLayout(const std::vector<std::shared_ptr<TensorLayout>> &in_tensor_layouts,
                                          const std::vector<std::shared_ptr<TensorLayout>> &out_tensor_layouts) {
  ResetQueueMember();

  if (InferAttrs() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": InferAttrs failed.";
    return FAILED;
  }

  for (const auto &input_layout : in_tensor_layouts) {
    TensorInfo input_tensor_info(*input_layout);
    inputs_tensor_info_.push_back(input_tensor_info);
  }
  if (CheckInputLayout() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": CheckInputLayout failed.";
    return FAILED;
  }
  for (const auto &output_layout : out_tensor_layouts) {
    TensorInfo output_tensor_info(*output_layout);
    outputs_tensor_info_.push_back(output_tensor_info);
  }

  if (outputs_tensor_info_.size() != outputs_shape_.size()) {
    outputs_tensor_info_.clear();
    // Needs to be overridden by the operator
    if (InferOutputTensorInfo() != SUCCESS) {
      MS_LOG(ERROR) << name_ << ": InferOutputTensorLayout failed.";
      return FAILED;
    }
  }

  if (outputs_tensor_info_.size() != outputs_shape_.size()) {
    MS_LOG(ERROR) << name_ << ": the number of output tensor layouts " << outputs_tensor_info_.size()
                  << " does not match the number of outputs " << outputs_shape_.size();
    return FAILED;
  }

  if (CheckOutputLayout() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": CheckLayout failed.";
    return FAILED;
  }

  // Needs to be overridden by the operator
  if (InferForwardCommunicationByLayout() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": InferForwardCommunication failed.";
    return FAILED;
  }

  if (InferMirrorOpsByLayout() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": InferMirrorOps failed.";
    return FAILED;
  }
  return SUCCESS;
}

std::vector<std::shared_ptr<Edge>> OperatorInfo::GetAliveSuccEdges() {
  std::vector<std::shared_ptr<Edge>> ret;
  for (auto &edge : succ_edges_) {

@@ -98,7 +98,9 @@ class OperatorInfo {
  // If output is tuple, outputs_type.size() is greater than 1.
  Status set_outputs_type(const std::vector<TypePtr> &outputs_type);
  const std::vector<TypePtr> &outputs_type() const { return outputs_type_; }
  virtual Status Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy);
  virtual Status Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy,
                      const std::vector<std::shared_ptr<TensorLayout>> &in_tensor_layouts = {},
                      const std::vector<std::shared_ptr<TensorLayout>> &out_tensor_layouts = {});
  // only init the necessary parts
  virtual Status InitForCostModel(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy);

@@ -241,12 +243,18 @@ class OperatorInfo {
  virtual Status CheckStrategy(const StrategyPtr &strategy) = 0;
  virtual Status InferTensorMap() = 0;
  virtual Status InferOutputTensorMap() { return SUCCESS; }
  virtual Status InferOutputTensorInfo() { return SUCCESS; }
  virtual Status CheckLayoutConfig() { return SUCCESS; }
  virtual Status CheckInputLayout() { return SUCCESS; }
  virtual Status CheckOutputLayout() { return SUCCESS; }
  virtual Status InferForwardCommunicationByLayout() { return SUCCESS; }
  virtual Status InferMirrorOpsByLayout();
  virtual Status InferForwardCommunication() = 0;
  virtual Status GetAttrs() = 0;
  virtual Status InferDevMatrixShape() = 0;
  virtual Status InferMirrorOps();
  virtual Status InferTensorInfo();

  virtual void InferReplaceOps() {}
  virtual Status CheckOutputStrategy(const StrategyPtr &out_strategy);
  Status CheckStrategyByVector(const Shapes &strategy, const Shapes &inputs_shape);

@@ -257,6 +265,8 @@ class OperatorInfo {
  virtual Status InferAttrs();
  void ResetQueueMember();
  Status InitWithAutoRepeatCalc(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy);
  Status InitWithTensorLayout(const std::vector<std::shared_ptr<TensorLayout>> &in_tensor_layouts,
                              const std::vector<std::shared_ptr<TensorLayout>> &out_tensor_layouts);
  Status InitForCostModelWithAutoRepeatCalc(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy);
  Status InferRepeatedCalcInfo();
  Status InferVirtualDivOps();

@@ -152,6 +152,9 @@ constexpr char RESHAPEINFO[] = "ReshapeInfo";
constexpr char GETNEXTINFO[] = "GetNextInfo";
constexpr char VIRTUALDATASETINFO[] = "VirtualDatasetInfo";
constexpr char FUNC_PARAM[] = "func_param";
constexpr char IN_LAYOUT[] = "in_layout";
constexpr char OUT_LAYOUT[] = "out_layout";
constexpr char DEVICE_MATRIX[] = "device_matrix";

constexpr char RELU_TYPE[] = "relu";
constexpr char RELU6_TYPE[] = "relu6";

@@ -646,6 +649,7 @@ constexpr char SEND_REC_DEPEND[] = "send_receive_depend";
constexpr char USER_NODE_STAGE[] = "user_node_stage";
constexpr char NODE_STAGE[] = "node_stage";
constexpr char SLICE_SHAPE[] = "slice_shape";
constexpr char FULL_SHAPE[] = "full_shape";
constexpr char SLICE_DTYPE[] = "slice_dtype";
constexpr char INPUT_PARAM[] = "input_param";
constexpr char ORIGIN_INPUT_IS_PARAM[] = "origin_input_is_param";

@@ -106,11 +106,6 @@ Shape QuantBatchMatmulInfo::GetCommonShape(const Dimensions &x1_strategy, const
}

Status QuantBatchMatmulInfo::GetAttrs() {
  if (attrs_.size() < MATMUL_ATTRS_SIZE) {
    MS_LOG(ERROR) << name_ << ": The size of attrs is smaller than 2, got " << attrs_.size();
    return FAILED;
  }

  ValuePtr transpose_a_ptr = input_value_.at(kQbmmInputTransposeX1);
  if (transpose_a_ptr != nullptr) {
    transpose_a_ = GetValue<bool>(transpose_a_ptr);

@@ -521,7 +521,9 @@ Status ReshapeInfo::InferDefaultLayout(const Shape &shape, TensorLayout *const l
  return Status::SUCCESS;
}

Status ReshapeInfo::Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) {
Status ReshapeInfo::Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy,
                         const std::vector<std::shared_ptr<TensorLayout>> &in_tensor_layouts,
                         const std::vector<std::shared_ptr<TensorLayout>> &out_tensor_layouts) {
  auto reshape_skip_redis_iter = attrs_.find(SKIP_REDISTRIBUTION);
  if (reshape_skip_redis_iter != attrs_.end()) {
    MS_EXCEPTION_IF_NULL(reshape_skip_redis_iter->second);

@@ -44,7 +44,9 @@ class ReshapeInfo : public OperatorInfo {
        input_layout_set_flag_(false),
        output_layout_set_flag_(false) {}
  ~ReshapeInfo() override = default;
  Status Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) override;
  Status Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy,
              const std::vector<std::shared_ptr<TensorLayout>> &in_tensor_layouts = {},
              const std::vector<std::shared_ptr<TensorLayout>> &out_tensor_layouts = {}) override;
  void SetInputLayout(const TensorLayout &input_layout) {
    input_layout_ = input_layout;
    input_layout_set_flag_ = true;

@@ -127,7 +127,9 @@ Status VirtualDatasetInfo::InferTensorMap() {

Status VirtualDatasetInfo::GetAttrs() { return SUCCESS; }

Status VirtualDatasetInfo::Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) {
Status VirtualDatasetInfo::Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy,
                                const std::vector<std::shared_ptr<TensorLayout>> &in_tensor_layouts,
                                const std::vector<std::shared_ptr<TensorLayout>> &out_tensor_layouts) {
  repeated_num_in_dev_matrix_right_ = false;
  if (ParallelContext::GetInstance()->dataset_repeat_dim_right()) {
    repeated_num_in_dev_matrix_right_ = true;

@@ -34,7 +34,9 @@ class VirtualDatasetInfo : public OperatorInfo {
                     const PrimitiveAttrs &attrs)
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<VirtualDatasetCost>()) {}
  ~VirtualDatasetInfo() override = default;
  Status Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) override;
  Status Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy,
              const std::vector<std::shared_ptr<TensorLayout>> &in_tensor_layouts = {},
              const std::vector<std::shared_ptr<TensorLayout>> &out_tensor_layouts = {}) override;
  Status InitForCostModel(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) override;

  std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;

@@ -466,15 +466,16 @@ void SliceParameterObj(const ParameterPtr &parameter, const TensorLayoutPtr &ten
  // create the python layout object
  const auto &device_arrangement = tensor_layout->device_arrangement().array();
  const auto &tensor_map = tensor_layout->tensor_map().array();
  auto slice_shape = tensor_layout->slice_shape().array();
  auto slice_shape = tensor_layout->base_slice_shape().array();
  int64_t field_size = tensor_layout->get_field_size();
  bool uniform_split = tensor_layout->uniform_split();
  std::string opt_shard_group = tensor_layout->opt_shard_group();
  if (!opt_shard_group.empty()) {
    slice_shape = tensor_layout->opt_shard_slice_shape();
  }
  auto full_shape = tensor_layout->tensor_shape().array();
  py::tuple layout =
    py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group);
    py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group, full_shape);

  // Call the Python _slice_parameter function to slice the Python parameter object
  (void)python_adapter::CallPyFn(SLICE_PARAMETER_FN_PATH, SLICE_PARAMETER_FN_NAME, py_obj, py::str(phase), layout);
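
For reference, the tuple handed to the Python slicing hook now carries the full shape as its last field. All concrete values below are hypothetical; only the field order is taken from the py::make_tuple call above:

```python
layout = (
    [2, 4],        # device_arrangement
    [1, -1],       # tensor_map
    [256, 1024],   # slice_shape (opt_shard_slice_shape when opt-sharded)
    0,             # field_size
    True,          # uniform_split
    "",            # opt_shard_group
    [512, 1024],   # full_shape (the new trailing field)
)
```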

@@ -1463,7 +1463,7 @@ static std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair
  }
  TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[LongToSize(res.second - 1)];
  TensorLayout tensor_layout = tensorinfo_in.tensor_layout();
  Shape slice_shape = tensor_layout.slice_shape().array();
  Shape slice_shape = tensor_layout.base_slice_shape().array();

  // generate shard group
  std::string opt_shard_group;

@@ -1527,7 +1527,7 @@ static void CoverSliceShape(const FuncGraphPtr &root) {
  if (parameter->has_user_data<TensorLayout>()) {
    auto param_abstract = parameter->abstract()->Clone();
    auto tensor_layout = parameter->user_data<TensorLayout>();
    Shape slice_shape = tensor_layout->slice_shape().array();
    Shape slice_shape = tensor_layout->base_slice_shape().array();
    param_abstract->set_shape(std::make_shared<abstract::Shape>(slice_shape));
    parameter->set_abstract(param_abstract);
  }

@@ -1652,9 +1652,15 @@ static void ExtractStrategyAndInit(const CNodePtr &cnode, const PrimitivePtr &pr
  } else {
    in_strategy = GenerateStandAloneStrategy(op_info->inputs_shape());
  }

  std::vector<std::shared_ptr<TensorLayout>> in_tensor_layouts;
  std::vector<std::shared_ptr<TensorLayout>> out_tensor_layouts;
  if (ExtractUserConfigLayout(attrs, op_info->inputs_shape(), op_info->outputs_shape(), &in_tensor_layouts,
                              &out_tensor_layouts) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " failed to extract the configured layout"
                      << trace::DumpSourceLines(cnode);
  }
  MS_EXCEPTION_IF_NULL(in_strategy);
  if (op_info->Init(in_strategy, out_strategy) == FAILED) {
  if (op_info->Init(in_strategy, out_strategy, in_tensor_layouts, out_tensor_layouts) == FAILED) {
    MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " init failed" << trace::DumpSourceLines(cnode);
  }
}

@@ -1969,6 +1969,105 @@ StrategyPtr ExtractStrategy(const ValuePtr &stra) {
  return strategyPtr;
}

Status GetLayoutFromAttrValue(const ValuePtr &layout_item, std::vector<int64_t> *device_matrix_vector,
                              std::vector<std::vector<int64_t>> *tensor_map_vector) {
  auto layout_dict_value = layout_item->cast<ValueDictionaryPtr>();
  if (!layout_dict_value) {
    MS_LOG(ERROR) << "The layout item configured for node is unreasonable";
    return FAILED;
  }
  auto layout_dict = layout_dict_value->value();
  ValuePtr device_matrix_value = nullptr;
  ValuePtr tensor_map_value = nullptr;
  for (const auto &value_pair : layout_dict) {
    if ((*value_pair.first) == (*MakeValue<std::string>(DEVICE_MATRIX))) {
      device_matrix_value = value_pair.second;
    }
    if ((*value_pair.first) == (*MakeValue<std::string>(TENSOR_MAP))) {
      tensor_map_value = value_pair.second;
    }
  }
  if (!device_matrix_value || !tensor_map_value) {
    MS_LOG(ERROR) << "The layout item configured for node is unreasonable";
    return FAILED;
  }
  *device_matrix_vector = GetValue<std::vector<int64_t>>(device_matrix_value);
  auto tensor_map_value_tuple = tensor_map_value->cast<ValueTuplePtr>();
  std::vector<ValuePtr> tensor_map_value_tuple_vector = tensor_map_value_tuple->value();
  for (const auto &tensor_map_item : tensor_map_value_tuple_vector) {
    if (tensor_map_item->isa<ValueSequence>()) {
      auto tensor_map_item_v = GetValue<std::vector<int64_t>>(tensor_map_item);
      tensor_map_vector->push_back(tensor_map_item_v);
      continue;
    }
    auto tensor_map_item_i = GetValue<int64_t>(tensor_map_item);
    tensor_map_vector->push_back({tensor_map_item_i});
  }
  return SUCCESS;
}
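
GetLayoutFromAttrValue walks a dictionary with a "device_matrix" entry and a "tensor_map" entry whose items are either ints or nested tuples. A simplified Python sketch of the same normalization (the attribute value shown is hypothetical):

```python
in_layout_attr = {"device_matrix": (2, 2, 2), "tensor_map": (2, (1, 0))}

def parse_layout_item(item):
    device_matrix = list(item["device_matrix"])
    # scalars become single-element sub-maps, sequences are kept as-is
    tensor_map = [list(t) if isinstance(t, tuple) else [t]
                  for t in item["tensor_map"]]
    return device_matrix, tensor_map

print(parse_layout_item(in_layout_attr))  # ([2, 2, 2], [[2], [1, 0]])
```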
|
||||
Status ExtractUserConfigLayout(const mindspore::HashMap<std::string, ValuePtr> &prim_attrs, const Shapes &inputs_shape,
                               const Shapes &outputs_shape,
                               std::vector<std::shared_ptr<TensorLayout>> *in_tensor_layouts,
                               std::vector<std::shared_ptr<TensorLayout>> *out_tensor_layouts) {
  if (prim_attrs.count(IN_LAYOUT) > 0) {
    auto layout_value = prim_attrs.at(IN_LAYOUT);
    if (!layout_value->isa<ValueSequence>()) {
      MS_LOG(ERROR) << "The in_layout configured for node is not a tuple";
      return FAILED;
    }
    auto layout_value_tuple = layout_value->cast<ValueTuplePtr>();
    std::vector<ValuePtr> layout_value_vector = layout_value_tuple->value();
    if (inputs_shape.size() != layout_value_vector.size()) {
      MS_LOG(ERROR) << "The in_layout configured for node is not equal to its input nums";
      return FAILED;
    }

    for (size_t i = 0; i < layout_value_vector.size(); ++i) {
      auto layout_item = layout_value_vector[i];
      std::vector<int64_t> device_matrix_vector;
      std::vector<std::vector<int64_t>> tensor_map_vector;
      if (GetLayoutFromAttrValue(layout_item, &device_matrix_vector, &tensor_map_vector) != SUCCESS) {
        return FAILED;
      }
      auto in_layout = std::make_shared<TensorLayout>();
      if (in_layout->InitFromExtendVector(device_matrix_vector, tensor_map_vector, inputs_shape[i]) != SUCCESS) {
        MS_LOG(ERROR) << "The in_layout configured incorrectly, device_matrix:" << device_matrix_vector
                      << ", tensor_map:" << tensor_map_vector;
        return FAILED;
      }
      in_tensor_layouts->push_back(in_layout);
    }
  }
  if (prim_attrs.count(OUT_LAYOUT) > 0) {
    auto layout_value = prim_attrs.at(OUT_LAYOUT);
    if (!layout_value->isa<ValueSequence>()) {
      MS_LOG(EXCEPTION) << "The out_layout configured for node is not a tuple";
    }
    auto layout_value_tuple = layout_value->cast<ValueTuplePtr>();
    std::vector<ValuePtr> layout_value_vector = layout_value_tuple->value();
    if (outputs_shape.size() != layout_value_vector.size()) {
      MS_LOG(EXCEPTION) << "The out_layout configured for node is not equal to its output nums";
    }
    for (size_t i = 0; i < layout_value_vector.size(); ++i) {
      auto layout_item = layout_value_vector[i];
      std::vector<int64_t> device_matrix_vector;
      std::vector<std::vector<int64_t>> tensor_map_vector;
      if (GetLayoutFromAttrValue(layout_item, &device_matrix_vector, &tensor_map_vector) != SUCCESS) {
        return FAILED;
      }
      auto out_layout = std::make_shared<TensorLayout>();
      if (out_layout->InitFromExtendVector(device_matrix_vector, tensor_map_vector, outputs_shape[i]) != SUCCESS) {
        MS_LOG(ERROR) << "The out_layout configured incorrectly, device_matrix:" << device_matrix_vector
                      << ", tensor_map:" << tensor_map_vector;
        return FAILED;
      }
      out_tensor_layouts->push_back(out_layout);
    }
  }
  return SUCCESS;
}

static bool IsCohesiveNode(const CNodePtr &cnode) {
  return IsPrimitiveCNode(cnode, prim::kPrimCast) || IsPrimitiveCNode(cnode, prim::kPrimLoad) ||
         IsPrimitiveCNode(cnode, prim::kPrimDepend) || IsPrimitiveCNode(cnode, prim::kPrimAllGather) ||
@@ -145,6 +145,10 @@ std::string GetSerialNumberString(size_t number);
bool IsIgnoreSplitTensor(const CNodePtr &node, int64_t index);
bool MergeConcatSlice(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager);
void UpdateMicroBatchInterleavedStatus(const std::vector<AnfNodePtr> &all_nodes);
Status ExtractUserConfigLayout(const mindspore::HashMap<std::string, ValuePtr> &prim_attrs, const Shapes &inputs_shape,
                               const Shapes &outputs_shape,
                               std::vector<std::shared_ptr<TensorLayout>> *in_tensor_layouts,
                               std::vector<std::shared_ptr<TensorLayout>> *out_tensor_layouts);
inline bool IsMakeSequence(const AnfNodePtr &node) {
  return AnfNodeIsPrimitive(node, MAKE_TUPLE) || AnfNodeIsPrimitive(node, MAKE_LIST);
}
@@ -20,6 +20,7 @@
#include "utils/ms_utils.h"
#include "ir/value.h"
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/status.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/tensor_layout/shape_util.h"
@@ -85,6 +86,91 @@ Status TensorLayout::InitFromVector(const Shape &device_arrangement, const Shape
  return SUCCESS;
}

/*
 * example1:
 * in_device_arrangement = [8, 2, 4],
 * in_tensor_map = [[2], [1, 0]],
 * in_tensor_shape = [512, 1024],
 * =>
 * in_device_arrangement = [8, 2, 4],
 * in_tensor_map = [2, 1, 0],
 * in_tensor_shape = [512, 2, 512],
 * example2:
 * in_device_arrangement = [8, 2, 4],
 * in_tensor_map = [[1], [0, 2]],
 * in_tensor_shape = [512, 1024],
 * =>
 * in_device_arrangement = [8, 2, 4],
 * in_tensor_map = [1, 0, 2],
 * in_tensor_shape = [512, 4, 256],
 */
Status TensorLayout::InitFromExtendVector(const Shape &device_arrangement, const std::vector<Shape> &tensor_map,
                                          const Shape &tensor_shape) {
  if (device_arrangement_origin_.Init(device_arrangement) != SUCCESS) {
    return FAILED;
  }
  CheckGlobalDeviceManager();
  auto device_num = g_device_manager->stage_device_num();
  int64_t device_total =
    std::accumulate(device_arrangement.begin(), device_arrangement.end(), 1, std::multiplies<int64_t>());
  if (device_num != device_total) {
    MS_LOG(ERROR) << "The configured device_matrix " << device_arrangement << " accumulate value " << device_total
                  << " does not equal the device number in one stage " << device_num;
    return FAILED;
  }
  Shape extended_tensor_map;
  Shape reshaped_tensor_shape;
  if (tensor_shape.size() != tensor_map.size()) {
    MS_LOG(ERROR) << "The tensor_shape " << tensor_shape << " does not have the same size as tensor_map "
                  << tensor_map;
    return FAILED;
  }

  size_t not_none_count = 0;
  for (size_t i = 0; i < tensor_map.size(); ++i) {
    for (size_t j = 0; j < tensor_map[i].size(); ++j) {
      extended_tensor_map.push_back(tensor_map[i][j]);
      if (tensor_map[i][j] > 0) {
        ++not_none_count;
      }
    }
  }

  if (not_none_count > device_arrangement.size()) {
    MS_LOG(ERROR) << "The device_matrix " << device_arrangement
                  << " length must be greater than or equal to the non-None size of extended_tensor_map "
                  << extended_tensor_map;
    return FAILED;
  }
  tensor_shape_before_.Init(tensor_shape);
  for (size_t i = 0; i < tensor_map.size(); ++i) {
    if (tensor_map[i].size() == 1) {
      reshaped_tensor_shape.push_back(tensor_shape[i]);
      continue;
    }
    int64_t accu_shp = 1;
    for (size_t j = 0; j < tensor_map[i].size() - 1; ++j) {
      size_t tensor_index = device_arrangement.size() - 1 - static_cast<size_t>(tensor_map[i][j]);
      auto shard_size = device_arrangement[tensor_index];
      accu_shp *= shard_size;
      reshaped_tensor_shape.push_back(shard_size);
    }
    auto last_shp = tensor_shape[i] / accu_shp;
    reshaped_tensor_shape.push_back(last_shp);
  }
  if (tensor_map_origin_.Init(extended_tensor_map) != SUCCESS) {
    return FAILED;
  }
  if (tensor_shape_origin_.Init(reshaped_tensor_shape) != SUCCESS) {
    return FAILED;
  }
  if (Init(device_arrangement_origin_, tensor_map_origin_, tensor_shape_origin_) != SUCCESS) {
    return FAILED;
  }
  tensor_map_before_ = tensor_map;
  return SUCCESS;
}

bool TensorLayout::IsValidTensorLayout() const {
  if (tensor_map_origin_.GetMaxItem() >= static_cast<int64_t>(device_arrangement_origin_.GetDimSize())) {
    MS_LOG(ERROR) << "the max element in tensor_map_origin_ must be smaller than device_arrangement_origin_ size!";
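The example1/example2 comment above is the heart of InitFromExtendVector: a grouped tensor_map implicitly reshapes each tensor dimension, peeling off one shard-sized dim per extra device axis. A minimal Python sketch of that shape inference (hypothetical helper name, mirroring the C++ loop above):

```python
def infer_reshaped_shape(device_matrix, tensor_map, tensor_shape):
    """Expand a grouped tensor_map into a flat map plus the reshaped tensor shape."""
    extended_map, reshaped_shape = [], []
    for dims, size in zip(tensor_map, tensor_shape):
        extended_map.extend(dims)
        if len(dims) == 1:
            reshaped_shape.append(size)
            continue
        accum = 1
        # Every grouped axis except the last contributes its shard size as a new dim.
        for d in dims[:-1]:
            shard = device_matrix[len(device_matrix) - 1 - d]
            accum *= shard
            reshaped_shape.append(shard)
        reshaped_shape.append(size // accum)
    return extended_map, reshaped_shape

# example1 from the comment block: [[2], [1, 0]] over devices [8, 2, 4]
print(infer_reshaped_shape([8, 2, 4], [[2], [1, 0]], [512, 1024]))
# -> ([2, 1, 0], [512, 2, 512])
```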
@@ -340,6 +426,37 @@ Arrangement TensorLayout::slice_shape() const {
  }
}

Arrangement TensorLayout::base_slice_shape() const {
  if (tensor_map_before_.empty()) {
    return slice_shape();
  }
  Shape shape;
  for (size_t index = 0; index < tensor_map_before_.size(); index++) {
    auto dim_map = tensor_map_before_[index];
    int64_t num = tensor_shape_before_.GetDimByIdx(index);
    int64_t axis_shard = 1;
    for (const auto &dim : dim_map) {
      if (dim != -1) {
        int64_t divisor = device_arrangement_origin_.GetDimByReverseIdx(LongToUlong(dim));
        axis_shard *= divisor;
      }
    }
    if (num == -1) {
      shape.push_back(num);  // num == -1 means dynamic shape
    } else {
      shape.push_back(num / axis_shard);
    }
  }

  Arrangement new_tensor_shape;
  if (new_tensor_shape.Init(shape) == Status::FAILED) {
    ValuePtr ptr = MakeValue(shape);
    MS_LOG(EXCEPTION) << "Can't get slice shape when initialize a new shape " << ptr->ToString();
  } else {
    return new_tensor_shape;
  }
}

Shape TensorLayout::shard_strategy() const {
  Shape ret;
  for (size_t index = 0; index < tensor_map_.GetDimSize(); index++) {
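base_slice_shape() returns the per-device shape before the implicit reshape: every dimension is divided by the product of all device-matrix axes it maps to, and -1 dims stay dynamic. A sketch under those assumptions (standalone function, not the actual class method):

```python
def base_slice_shape(device_matrix, tensor_map_before, tensor_shape_before):
    """Per-device slice shape: divide each dim by all of its shard factors."""
    slice_shape = []
    for dims, num in zip(tensor_map_before, tensor_shape_before):
        axis_shard = 1
        for d in dims:
            if d != -1:  # -1 means this device axis does not shard the dim
                axis_shard *= device_matrix[len(device_matrix) - 1 - d]
        slice_shape.append(num if num == -1 else num // axis_shard)  # -1: dynamic dim
    return slice_shape

# (("dp", "sp") on dim0, "mp" on dim1) over Layout((2, 2, 2), ...):
print(base_slice_shape([2, 2, 2], [[2, 1], [0]], [1024, 1024]))  # -> [256, 512]
```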
@@ -445,11 +562,22 @@ TensorLayout TensorLayout::TransferRepeatLayout() const {
  return repeat;
}

RankList TensorLayout::InferRepeatedGroup() {
  CheckGlobalDeviceManager();
  int64_t rank = g_device_manager->global_rank();
  DeviceMatrix dev_matrix(rank, g_device_manager->GetDeviceListInThisStage(), device_arrangement_origin_.array());
  RankList group_devices;
  if (dev_matrix.GetDevicesByTensorMap(tensor_map_origin_.array(), &group_devices) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Tensor layout:" << ToString() << " infer repeated group failed.";
  }
  return group_devices;
}

// Generate a totally shard tensor slice shape for parallel optimizer
Status TensorLayout::GenerateOptShardSliceShape() {
  MS_LOG(INFO) << "layout for GetOptShardSliceShape is " << StandardToString();
  Shape dev_max = device_arrangement_.array();
  Shape tensor_map = tensor_map_.array();

  Shape repeated_dev;
  for (size_t i = 0; i < dev_max.size(); i++) {
    if (tensor_map_.GetIndexByValue(static_cast<int64_t>(i)) == MAP_NONE) {
@@ -463,22 +591,17 @@ Status TensorLayout::GenerateOptShardSliceShape() {
  }
  int64_t repeated_num =
    std::accumulate(repeated_dev.begin(), repeated_dev.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
  int64_t split_num;
  int64_t optimizer_weight_shard_size = ParallelContext::GetInstance()->optimizer_weight_shard_size();
  if (optimizer_weight_shard_size != -1 && repeated_num >= optimizer_weight_shard_size) {
    repeated_num = optimizer_weight_shard_size;
  }
  if (tensor_map[0] == MAP_NONE) {
    split_num = repeated_num;
  } else {
    split_num = dev_max[dev_max.size() - 1 - static_cast<size_t>(tensor_map[0])] * repeated_num;
  }
  if (tensor_shape_.array()[0] % split_num != 0) {
  Shape origin_slice_shape = base_slice_shape().array();
  if (origin_slice_shape[0] % repeated_num != 0) {
    MS_LOG(INFO) << "Tensor could not be shard on the first dimension.";
    return Status::FAILED;
  }
  Shape origin_slice_shape = slice_shape().array();
  origin_slice_shape[0] = tensor_shape_.array()[0] / split_num;
  origin_slice_shape[0] = origin_slice_shape[0] / repeated_num;
  opt_shard_slice_shape_ = origin_slice_shape;
  return Status::SUCCESS;
}

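The rewritten GenerateOptShardSliceShape splits the first dimension of the base slice across the repeated devices, after capping the repeat count with optimizer_weight_shard_size. Roughly, in Python (a hedged sketch with hypothetical names, not the C++ method itself):

```python
def opt_shard_slice_shape(base_slice, repeated_num, optimizer_weight_shard_size=-1):
    """First dim of the base slice is split again across the repeated devices."""
    if optimizer_weight_shard_size != -1 and repeated_num >= optimizer_weight_shard_size:
        repeated_num = optimizer_weight_shard_size  # cap by the configured opt-shard size
    if base_slice[0] % repeated_num != 0:
        return None  # tensor cannot be sharded on the first dimension
    return [base_slice[0] // repeated_num] + base_slice[1:]

print(opt_shard_slice_shape([256, 512], repeated_num=4, optimizer_weight_shard_size=2))
# -> [128, 512]
```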
@@ -41,7 +41,8 @@ class TensorLayout {
  std::string OriginToString() const;
  Status Init(const Arrangement &device_arrangement, const Map &tensor_map, const Arrangement &tensor_shape);
  Status InitFromVector(const Shape &device_arrangement, const Shape &tensor_map, const Shape &tensor_shape);

  Status InitFromExtendVector(const Shape &device_arrangement, const std::vector<Shape> &tensor_map,
                              const Shape &tensor_shape);
  bool skip_redistribution() const { return skip_redistribution_; }

  void set_skip_redistribution(bool flag) { skip_redistribution_ = flag; }
@@ -96,6 +97,8 @@ class TensorLayout {

  Arrangement slice_shape() const;

  Arrangement base_slice_shape() const;

  Shape shard_strategy() const;

  Status UpdateTensorMap(size_t index, int64_t value);
@@ -128,6 +131,13 @@ class TensorLayout {

  bool is_shared_param() const { return is_shared_param_; }

  void set_tensor_shape_before(const Shape &tensor_shape_before) { tensor_shape_before_.Init(tensor_shape_before); }

  RankList InferRepeatedGroup();

  Arrangement tensor_shape_before() { return tensor_shape_before_; }

  std::vector<Shape> tensor_map_before() { return tensor_map_before_; }
  // Key for user data.
  constexpr static char key[] = "TLayout";
@@ -146,8 +156,10 @@ class TensorLayout {
  Arrangement tensor_shape_origin_;
  Arrangement device_arrangement_;
  Arrangement tensor_shape_;
  Arrangement tensor_shape_before_;
  Map tensor_map_;
  Map tensor_map_origin_;
  std::vector<Shape> tensor_map_before_;
  bool skip_redistribution_ = false;
  bool uniform_split_ = true;
  bool layout_transfer_ = false;
@@ -159,6 +159,19 @@ Status TensorRedistribution::InferReshape(const TensorLayout &from_layout, const
    }
  }

  if (from_origin_.base_slice_shape().array() != from_origin_.slice_shape().array()) {
    reshape_flag_ = true;
    constructor.UpdateTensorShape(from_origin_.base_slice_shape().array());
    Arrangement shape = from_origin_.slice_shape();
    MS_LOG(DEBUG) << "reshape " << shape.ToString();
    if (constructor.ReshapeOP(shape.array()) == Status::FAILED) {
      return Status::FAILED;
    } else {
      (void)operator_vector->insert(operator_vector->cbegin(), constructor.GetOperator());
      (void)output_info_vector->insert(output_info_vector->cbegin(), std::make_pair(false, 0));
    }
  }

  if (to_origin_.slice_shape().array() != to_layout.slice_shape().array()) {
    reshape_flag_ = true;
    constructor.UpdateTensorShape(to_layout.slice_shape().array());
@@ -171,6 +184,19 @@ Status TensorRedistribution::InferReshape(const TensorLayout &from_layout, const
      (void)output_info_vector->insert(output_info_vector->cend(), std::make_pair(false, 0));
    }
  }

  if (to_origin_.slice_shape().array() != to_origin_.base_slice_shape().array()) {
    reshape_flag_ = true;
    constructor.UpdateTensorShape(to_origin_.slice_shape().array());
    Arrangement shape = to_origin_.base_slice_shape();
    MS_LOG(DEBUG) << "step_parallel to reshape " << shape.ToString();
    if (constructor.ReshapeOP(shape.array()) == Status::FAILED) {
      return Status::FAILED;
    } else {
      (void)operator_vector->insert(operator_vector->cend(), constructor.GetOperator());
      (void)output_info_vector->insert(output_info_vector->cend(), std::make_pair(false, 0));
    }
  }
  return Status::SUCCESS;
}

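Together, these two branches bracket the redistribution with reshapes whenever the extended layout's base slice differs from the flattened slice. Conceptually (a sketch of the operator-list surgery under that assumption, not the real Constructor API):

```python
def bracket_with_reshapes(ops, from_base, from_slice, to_slice, to_base):
    """Prepend/append Reshape ops so redistribution runs on flattened slices."""
    if from_base != from_slice:
        ops.insert(0, ("Reshape", from_slice))  # base slice -> flattened slice
    if to_slice != to_base:
        ops.append(("Reshape", to_base))        # flattened slice -> base slice
    return ops

print(bracket_with_reshapes([("AllToAll", None)],
                            [256, 512], [128, 1024], [128, 1024], [256, 512]))
# -> [('Reshape', [128, 1024]), ('AllToAll', None), ('Reshape', [256, 512])]
```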
@@ -286,9 +286,11 @@ RedistributionOpListPtr TensorTransform::OptimizeTensorRedistributionOperatorLis
    }
    auto axis = transform_op_list[i].second.back();
    if (axis == 0) {
      current_allgather_pos_in_origin_list += kSize3;
      continue;
    }
    if (i == transform_op_list.size() - 1 || transform_op_list[i + 1].first != RESHAPE) {
      current_allgather_pos_in_origin_list += kSize3;
      continue;
    }
    auto src_shape = shape_list[i];
@@ -305,6 +307,7 @@ RedistributionOpListPtr TensorTransform::OptimizeTensorRedistributionOperatorLis
    MS_LOG(INFO) << "src_shape:" << src_shape << ", new_src_shape:" << new_src_shape << ", axis:" << axis
                 << ", new_axis:" << new_axis;
    if (new_axis != 0) {
      current_allgather_pos_in_origin_list += kSize3;
      continue;
    }
    left_reshape_op_list[current_allgather_pos_in_origin_list] = new_src_shape;
@@ -67,6 +67,9 @@ if(ENABLE_D)
    list(APPEND MSLIB_INFER_SRC
        "${CMAKE_SOURCE_DIR}/../ccsrc/frontend/parallel/strategy_checkpoint/strategy_checkpoint_info.cc"
        "${CMAKE_SOURCE_DIR}/../ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc"
        "${CMAKE_SOURCE_DIR}/../ccsrc/frontend/parallel/group_manager.cc"
        "${CMAKE_SOURCE_DIR}/../ccsrc/frontend/parallel/device_manager.cc"
        "${CMAKE_SOURCE_DIR}/../ccsrc/frontend/parallel/device_matrix.cc"
        "${CMAKE_SOURCE_DIR}/../ccsrc/frontend/parallel/tensor_layout/array.cc"
        "${CMAKE_SOURCE_DIR}/../ccsrc/frontend/parallel/tensor_layout/map.cc"
        "${CMAKE_SOURCE_DIR}/../ccsrc/frontend/parallel/tensor_layout/arrangement.cc"
@@ -29,7 +29,8 @@ from mindspore.context import GRAPH_MODE, PYNATIVE_MODE, set_context, get_contex
from mindspore.version import __version__
from mindspore.profiler import Profiler, EnvProfiler
from mindspore.parallel import set_algo_parameters, get_algo_parameters, reset_algo_parameters, \
    rank_list_for_transform, transform_checkpoint_by_rank, transform_checkpoints, merge_pipeline_strategys, shard
    rank_list_for_transform, transform_checkpoint_by_rank, transform_checkpoints, merge_pipeline_strategys, shard, \
    Layout
from mindspore.rewrite import SymbolTree, ScopedValue, Node, NodeType
from mindspore.safeguard import obfuscate_ckpt, load_obf_params_into_net
from mindspore._check_jit_forbidden_api import get_obj_module_and_name_info, is_jit_forbidden_module, \
@@ -1295,8 +1295,11 @@ class _GetTensorSlice(PrimitiveWithInfer):
        """Initialize _GetTensorSlice."""
        self.add_prim_attr('order_enforce_skip', True)

    def infer_value(self, x, dev_mat, tensor_map):
    def infer_value(self, x, dev_mat, tensor_map, slice_shape, full_shape):
        from mindspore.parallel._tensor import _load_tensor
        validator.check_value_type("dev_mat", dev_mat, [tuple], self.name)
        validator.check_value_type("tensor_map", tensor_map, [tuple], self.name)
        return Tensor(_load_tensor(x, dev_mat, tensor_map), x.dtype)
        tensor_slice = _load_tensor(x, dev_mat, tensor_map, full_shape)
        if tensor_slice.shape != slice_shape:
            tensor_slice = tensor_slice.reshape(slice_shape)
        return Tensor(tensor_slice)
@@ -1,4 +1,4 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# Copyright 2020-2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,11 +17,13 @@
import functools
import inspect
import copy
import numpy as np
from mindspore.common.api import _wrap_func
from mindspore.log import _LogActionOnce
from mindspore import context, log as logger
from mindspore.parallel._utils import _is_in_auto_parallel_mode, _is_in_data_parallel_mode, _is_in_hybrid_parallel_mode
from mindspore.parallel._ps_context import _is_ps_mode, _is_role_sched
from mindspore.parallel.shard import Layout
from mindspore.common.api import _pynative_executor
from mindspore.common._stub_tensor import _convert_stub
from mindspore._c_expression import Primitive_, PrimitiveFunction_, prim_type, typing
@@ -155,6 +157,63 @@ class Primitive(Primitive_):
        cloned.set_prim_instance_name(self.instance_name)
        return cloned

    def _check_shard_strategy(self, strategy, log_info):
        """Check whether the shard strategy is valid."""
        is_layout = []
        if not isinstance(strategy, tuple):
            raise TypeError(f'{log_info} must be tuple type, but got:{type(strategy)}')
        for in_ele in strategy:
            if not isinstance(in_ele, tuple) and not isinstance(in_ele, Layout):
                raise TypeError(f'The element of strategy must be tuple/Layout type, but got:{type(in_ele)}')
            if isinstance(in_ele, tuple):
                for in_value in in_ele:
                    if not isinstance(in_value, int):
                        raise TypeError(f'The {log_info}: {strategy} of {self.name} is not valid,'
                                        f' the value of strategy must be int type, but got:{type(in_value)}')
                is_layout.append(False)
                continue
            is_layout.append(True)
        if is_layout:
            np_is_layout = np.array(is_layout)
            if not (np_is_layout == np_is_layout[0]).all():
                raise TypeError(f'{log_info} item must be all tuple type or all Layout type.')
        return np.array(is_layout)

    def _extract_layout_value(self, layout, log_info):
        """Extract parallel layout value"""
        layout_value = None
        if layout is not None:
            if not isinstance(layout, tuple):
                raise TypeError(f'{log_info} must be tuple type, but got:{type(layout)}')
            layout_value = ()
            for in_ele in layout:
                if not isinstance(in_ele, Layout):
                    raise TypeError(f"The {log_info} item should be an object of class Layout.")
                layout_value += (in_ele.to_dict(),)
        return layout_value

    def _check_shard_strategy_in_out_match(self, in_strategy, out_strategy):
        """Check shard in_strategy and out_strategy"""
        if in_strategy is None and out_strategy is not None:
            raise ValueError(f'The out_strategy of {self.name} is {out_strategy}, need to set in_strategy,'
                             f' but got none')
        if not _is_in_auto_parallel_mode():
            mode = context.get_auto_parallel_context("parallel_mode")
            if in_strategy is not None:
                logger.warning(f"The in_strategy/in_layout of the operator in your network "
                               f"will not take effect in {mode} mode. "
                               f"This means the shard function called in the network is ignored. \n"
                               f"If you want to enable it, please use semi auto or auto parallel mode by "
                               f"context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL "
                               f"or context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL)")
            if out_strategy is not None:
                logger.warning(f"The out_strategy/out_layout of the operator in your network "
                               f"will not take effect in {mode} mode."
                               f" This means the shard function called in the network is ignored. \n"
                               f"If you want to enable it, please use semi auto or auto parallel mode by "
                               f"context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL "
                               f"or context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL)")

    def del_prim_attr(self, name):
        """
        Delete primitive attribute.
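The helper enforces a homogeneous strategy tuple: every element is a plain int tuple, or every element is a Layout. For example (assuming a build with this patch; outside semi-auto/auto mode shard() only records the attribute and warns):

```python
from mindspore import ops, Layout

layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
ops.MatMul().shard((layout("dp", "sp"), layout("sp", "mp")))   # all Layout: accepted
ops.MatMul().shard(((2, 1), (1, 4)))                           # all int tuples: accepted
# ops.MatMul().shard((layout("dp", "sp"), (1, 4)))             # mixed: raises TypeError
```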
@@ -212,50 +271,52 @@ class Primitive(Primitive_):
        >>> add = ops.Add()
        >>> print(add.shard(((1, 1), (1, 1))))
        Prim[Add]<in_strategy=((1, 1), (1, 1)), out_strategy=None>
        >>> # using layout
        >>> from mindspore import Layout
        >>> layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
        >>> layout_tuple = (layout("dp", "sp"), layout("sp", "mp"))
        >>> from mindspore import ops
        >>> matmul = ops.MatMul()
        >>> print(matmul.shard(layout_tuple))
        Prim[MatMul]<in_layout=({'device_matrix': (2, 2, 2), 'tensor_map': (2, 1)},
        {'device_matrix': (2, 2, 2), 'tensor_map': (1, 0)})>
        >>> # using layout with None
        >>> from mindspore import Layout
        >>> layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
        >>> layout_tuple = (layout("dp", "sp"), layout("sp", "None"))  # "None" means the axis would not be split
        >>> from mindspore import ops
        >>> matmul = ops.MatMul()
        >>> print(matmul.shard(layout_tuple))
        Prim[MatMul]<in_layout=({'device_matrix': (2, 2, 2), 'tensor_map': (2, 1)},
        {'device_matrix': (2, 2, 2), 'tensor_map': (1, -1)})>
        """
        mode = context.get_auto_parallel_context("parallel_mode")
        in_is_layout = None
        out_is_layout = None
        if in_strategy is not None:
            if not isinstance(in_strategy, tuple):
                raise TypeError(f'in_strategy must be tuple type, but got:{type(in_strategy)}')
            for in_ele in in_strategy:
                if not isinstance(in_ele, tuple):
                    raise TypeError(f'The element of strategy must be tuple type, but got:{type(in_ele)}')
                for in_value in in_ele:
                    if not isinstance(in_value, int):
                        raise TypeError(f'The in_strategy: {in_strategy} of {self.name} is not valid,'
                                        f' the value of strategy must be int type, but got:{type(in_value)}')
            in_is_layout = self._check_shard_strategy(in_strategy, "in_strategy")

        if out_strategy is not None:
            if not isinstance(out_strategy, tuple):
                raise TypeError(f'out strategy must be tuple type, but got:{type(out_strategy)}')
            for out_ele in out_strategy:
                if not isinstance(out_ele, tuple):
                    raise TypeError(f'The element of strategy must be tuple type, but got:{type(out_ele)}')
                for out_value in out_ele:
                    if not isinstance(out_value, int):
                        raise TypeError(f'The out_strategy: {out_strategy} of {self.name} is not valid,'
                                        f' the value of strategy must be int type, but got:{type(out_value)}')
            out_is_layout = self._check_shard_strategy(out_strategy, "out_strategy")
        self._check_shard_strategy_in_out_match(in_strategy, out_strategy)
        if in_is_layout is not None and out_is_layout is not None and in_is_layout[0] != out_is_layout[0]:
            raise ValueError(f'The in_strategy type must be equal to the out_strategy type, '
                             f'one using tuple(tuple) and the other using tuple(Layout) is not allowed.')
        in_layout_value = None
        out_layout_value = None
        if in_is_layout is not None and in_is_layout[0]:
            in_layout_value = self._extract_layout_value(in_strategy, "in_strategy")
        if out_is_layout is not None and out_is_layout[0]:
            out_layout_value = self._extract_layout_value(out_strategy, "out_strategy")

        if in_strategy is None and out_strategy is not None:
            raise ValueError(f'The out_strategy of {self.name} is {out_strategy}, need to set in_strategy,'
                             f' but got none')

        if not _is_in_auto_parallel_mode():
            if in_strategy is not None:
                logger.warning(f"The in_strategy of the operator in your network will not take effect in {mode} mode. "
                               f"This means the shard function called in the network is ignored. \n"
                               f"If you want to enable it, please use semi auto or auto parallel mode by "
                               f"context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL "
                               f"or context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL)")
            if out_strategy is not None:
                logger.warning(f"The out_strategy of the operator in your network will not take effect in {mode} mode."
                               f" This means the shard function called in the network is ignored. \n"
                               f"If you want to enable it, please use semi auto or auto parallel mode by "
                               f"context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL "
                               f"or context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL)")

        self.add_prim_attr("in_strategy", in_strategy)
        self.add_prim_attr("out_strategy", out_strategy)
        if in_is_layout is not None and not in_is_layout[0]:
            self.add_prim_attr("in_strategy", in_strategy)
        if out_is_layout is not None and not out_is_layout[0]:
            self.add_prim_attr("out_strategy", out_strategy)
        if in_layout_value:
            self.add_prim_attr("in_layout", in_layout_value)
        if out_layout_value:
            self.add_prim_attr("out_layout", out_layout_value)
        return self

    def set_prim_instance_name(self, instance_name):
@@ -19,8 +19,8 @@ from mindspore.parallel.algo_parameter_config import get_algo_parameters, reset_
    set_algo_parameters
from mindspore.parallel.checkpoint_transform import rank_list_for_transform, transform_checkpoint_by_rank, \
    transform_checkpoints, merge_pipeline_strategys, sync_pipeline_shared_parameters
from mindspore.parallel.shard import shard
from mindspore.parallel.shard import shard, Layout

__all__ = ["set_algo_parameters", "reset_algo_parameters", "get_algo_parameters", "rank_list_for_transform",
           "transform_checkpoint_by_rank", "transform_checkpoints", "merge_pipeline_strategys", "shard",
           "sync_pipeline_shared_parameters"]
           "sync_pipeline_shared_parameters", "Layout"]
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division

from mindspore.nn.cell import Cell
from mindspore.ops import operations as P
from mindspore.ops.operations.comm_ops import AllGather
from mindspore.communication import GlobalComm
from mindspore.common import jit
@@ -29,16 +30,18 @@ class AllGatherCell(Cell):
    Allgather cell, used in model parallel scenario.
    To allgather the selected parameter slice from each device.
    """
    def __init__(self, group):
    def __init__(self, group, do_reshape, after_reshape_slice_shape):
        super(AllGatherCell, self).__init__(auto_prefix=False)

        self.allgather = AllGather(group)
        self.do_reshape = do_reshape
        self.after_reshape_slice_shape = tuple(after_reshape_slice_shape)
        self.add_flags(skip_auto_parallel_compile=True)

    @jit()
    def construct(self, x):
        if self.do_reshape:
            x = P.Reshape()(x, self.after_reshape_slice_shape)
        x = self.allgather(x)

        return x

@@ -51,29 +54,33 @@ class SaveOptShardCkptCell(Cell):
    Note:
        This could be optimized later with less communication consumption.
    """
    def __init__(self, group):
    def __init__(self, group, do_reshape, after_reshape_slice_shape):
        super(SaveOptShardCkptCell, self).__init__(auto_prefix=False)
        self.allgather1 = AllGather(group)
        self.allgather2 = AllGather()
        self.do_reshape = do_reshape
        self.after_reshape_slice_shape = tuple(after_reshape_slice_shape)
        self.add_flags(skip_auto_parallel_compile=True)

    def construct(self, x):
        x = self.allgather1(x)
        if self.do_reshape:
            x = P.Reshape()(x, self.after_reshape_slice_shape)
        x = self.allgather2(x)

        return x

def get_allgather_cell(group, need_merge_twice=False):
def get_allgather_cell(group, need_merge_twice=False, do_reshape=False, after_reshape_slice_shape=()):
    """Get AllGatherCell object."""
    global _ALLGATHER_CELL
    if need_merge_twice:
        _ALLGATHER_CELL = SaveOptShardCkptCell(group)
        _ALLGATHER_CELL = SaveOptShardCkptCell(group, do_reshape, after_reshape_slice_shape)
    else:
        if group:
            _ALLGATHER_CELL = AllGatherCell(group)
            _ALLGATHER_CELL = AllGatherCell(group, do_reshape, after_reshape_slice_shape)
        else:
            _ALLGATHER_CELL = AllGatherCell(GlobalComm.WORLD_COMM_GROUP)
            _ALLGATHER_CELL = AllGatherCell(GlobalComm.WORLD_COMM_GROUP, do_reshape, after_reshape_slice_shape)
    return _ALLGATHER_CELL

@@ -16,6 +16,7 @@
from __future__ import division
from __future__ import absolute_import

import copy
import numpy as np
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
@@ -198,7 +199,7 @@ def _get_slice_index(dev_mat, tensor_map, opt_shard_group):
    return tensor_slice_index


def _load_tensor(tensor, dev_mat, tensor_map, rank_id=-1):
def _load_tensor(tensor, dev_mat, tensor_map, full_shape, rank_id=-1):
    """
    Get the tensor slice of the local device by the device matrix and the tensor map

@@ -214,7 +215,8 @@ def _load_tensor(tensor, dev_mat, tensor_map, rank_id=-1):
        >>> tensor = Tensor(np.ones([32, 32]))
        >>> dev_mat = [2, 4]
        >>> tensor_map = [1, -1]
        >>> tensor_slice = _load_tensor(tensor, dev_mat, tensor_map)
        >>> full_shape = [32, 32]
        >>> tensor_slice = _load_tensor(tensor, dev_mat, tensor_map, full_shape)
    """
    if rank_id == -1:
        rank = get_rank()
@@ -227,6 +229,7 @@ def _load_tensor(tensor, dev_mat, tensor_map, rank_id=-1):
    cpu_cast = Cast().set_device("CPU")
    tensor = cpu_cast(tensor, mstype.float32)
    np_tensor = tensor.asnumpy()
    np_tensor = np_tensor.reshape(full_shape)
    np_tensor_list = _chunk_tensor_by_strategy(np_tensor, tensor_strategy)
    np_tensor_slice = np_tensor_list[int(tensor_slice_index)]
    return np_tensor_slice
@@ -249,21 +252,29 @@ def _load_tensor_by_layout(tensor, layout, rank_id):
    """
    if not isinstance(layout, tuple):
        raise TypeError("The layout should be tuple! layout is {}".format(layout))
    if len(layout) < 6:
        raise ValueError("The length of layout must be larger than 5! layout is {}".format(layout))
    if len(layout) < 7:
        raise ValueError("The length of layout must be larger than 6! layout is {}".format(layout))
    dev_mat = layout[0]
    tensor_map = layout[1]
    slice_shape = layout[2]
    if not tensor_map:
        return tensor
    uniform_split = layout[4]
    group = layout[5]
    full_shape = layout[6]
    if uniform_split == 0:
        raise RuntimeError("The load tensor only support uniform split now")
    tensor_slice = _load_tensor(tensor, dev_mat, tensor_map, rank_id)
    tensor_slice = _load_tensor(tensor, dev_mat, tensor_map, full_shape, rank_id)
    if tensor_slice.shape != slice_shape and not group:
        tensor_slice = tensor_slice.reshape(slice_shape)
    if group:
        # get a totally shard tensor slice for parallel optimizer
        rank = get_rank(group)
        size = get_group_size(group)
        if tensor_slice.shape != slice_shape and slice_shape:
            slice_shape_extend = copy.deepcopy(slice_shape)
            slice_shape_extend[0] = slice_shape[0] * size
            tensor_slice = tensor_slice.reshape(slice_shape_extend)
        tensor_slice = np.split(tensor_slice, size)[rank]
    return Tensor(tensor_slice, tensor.dtype)

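With this change a parameter layout tuple carries at least seven fields; this function reads dev_mat, tensor_map and slice_shape from positions 0-2, uniform_split and the opt-shard group from positions 4-5, and the new full_shape from position 6 (position 3 is assumed here to be field_size). A hypothetical tuple wired the same way:

```python
# Hypothetical 7-field layout for a [32, 32] parameter on a [2, 4] device matrix:
# (dev_mat, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group, full_shape)
layout = ([2, 4], [1, -1], [16, 32], 0, True, '', [32, 32])
dev_mat, tensor_map, slice_shape = layout[0], layout[1], layout[2]
uniform_split, group, full_shape = layout[4], layout[5], layout[6]
assert uniform_split and not group  # uniform split, no optimizer sharding
# dim0 is split by dev_mat axis 1 (size 2), dim1 is replicated: slice is [16, 32].
```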
@@ -345,13 +345,13 @@ def _sync_params(name, param, layout):
        ms.log.warning("The layout dict does not contain the pipeline_shared_param info %s", name)
        return

    pipeline_shared = layout[6]
    pipeline_shared = layout[8]
    if not pipeline_shared:
        return

    is_send = layout[7]
    peer_rank = layout[8]
    sr_tag = layout[9]
    is_send = layout[9]
    peer_rank = layout[10]
    sr_tag = layout[11]

    class SharedParameterSyncCell(ms.nn.Cell):
        """synchronize cell"""
@@ -14,11 +14,86 @@
# ============================================================================
"""pynative shard"""

import copy
import mindspore as ms
from mindspore import log as logger
from mindspore._c_expression import Shard_


class Layout():
    """
    Parallel layout describes the detailed sharding information.

    Note:
        It is valid only in semi auto parallel or auto parallel mode.

    Args:
        device_matrix (tuple): Describe the shape of devices arrangement, its element type is int.
        alias_name (tuple): The alias name for each axis of device_matrix, its length should be equal to that of
            device_matrix, and its element type is string.

    Examples:
        >>> from mindspore import Layout
        >>> layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
        >>> layout0 = layout("dp", "mp")
        >>> print(layout0.to_dict())
        {"device_matrix": (2, 2, 2), "tensor_map": (2, 0)}
    """

    def __init__(self, device_matrix, alias_name):
        if not isinstance(device_matrix, tuple):
            raise TypeError(f'device_shape must be tuple type, but got:{type(device_matrix)}')
        if not isinstance(alias_name, tuple):
            raise TypeError(f'alias_name must be tuple type, but got:{type(alias_name)}')
        if len(device_matrix) != len(alias_name):
            raise ValueError(f'device_matrix length should be equal to alias_name length')
        for in_ele in device_matrix:
            if not isinstance(in_ele, int):
                raise TypeError(f'The element of device_matrix must be int type, but got:{type(in_ele)}')
        for in_ele in alias_name:
            if not isinstance(in_ele, str):
                raise TypeError(f'The element of alias_name must be str type, but got:{type(in_ele)}')
            if in_ele == "None":
                raise ValueError(f"The element of alias_name can not set 'None', because 'None' means no sharding.")
        if len(set(alias_name)) != len(alias_name):
            raise ValueError(f'Each element of alias_name {alias_name} should be different')
        self._device_shape = device_matrix
        self._alias_name = alias_name
        self._tensor_map = None

    def __call__(self, *tensor_map):
        self._tensor_map = ()
        for ele in tensor_map:
            if isinstance(ele, tuple):
                map = ()
                for item in ele:
                    if item == "None":
                        map += (-1,)
                        continue
                    if item not in self._alias_name:
                        raise ValueError(f'The axis {item} is not found in {self._alias_name}')
                    map += (len(self._alias_name) - 1 - self._alias_name.index(item),)
                self._tensor_map += (map,)
                continue
            if ele == "None":
                self._tensor_map += (-1,)
                continue
            if ele not in self._alias_name:
                raise ValueError(f'The axis {ele} is not found in {self._alias_name}')
            self._tensor_map += (len(self._alias_name) - 1 - self._alias_name.index(ele),)
        return copy.deepcopy(self)

    def to_dict(self):
        """
        Transform layout to a dictionary.
        """
        if self._device_shape is None:
            raise ValueError("The device_shape of layout is None")
        if self._tensor_map is None:
            raise ValueError("The tensor_map of layout is None")
        return {"device_matrix": self._device_shape, "tensor_map": self._tensor_map}



class Shard(Shard_):
    """Shard operation"""

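The alias-to-index rule in __call__ is plain reverse positional indexing: with alias_name ("dp", "sp", "mp"), "dp" maps to 2, "sp" to 1, "mp" to 0; a tuple of aliases yields a nested tuple and "None" yields -1. For example (with this PR applied):

```python
from mindspore import Layout

layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
print(layout(("dp", "sp"), "mp").to_dict())
# {'device_matrix': (2, 2, 2), 'tensor_map': ((2, 1), 0)}
print(layout("sp", "None").to_dict())
# {'device_matrix': (2, 2, 2), 'tensor_map': (1, -1)}
```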
@@ -1451,6 +1451,11 @@ def _save_graph(network, file_name):
        os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
        f.write(graph_pb)

def _reshape_tensor(tensor, dst_shape):
    """reshape tensor to dst shape"""
    np_tensor = tensor.asnumpy()
    np_tensor = np_tensor.reshape(dst_shape)
    return Tensor(np_tensor, tensor.dtype)

def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, integrated_save):
    """
@@ -1465,7 +1470,7 @@ def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, i
        Tensor, the combined tensor with the whole data value.
    """
    layout = parameter_layout_dict[param_name]
    if len(layout) < 6:
    if len(layout) < 8:
        logger.info("The layout dict does not contain the key %s", param_name)
        return param_data

@@ -1473,6 +1478,13 @@ def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, i
    tensor_map = layout[1]
    uniform_split = layout[4]
    opt_shard_group = layout[5]
    before_reshape_slice_shape = layout[2]
    before_reshape_full_shape = layout[6]
    after_reshape_slice_shape = layout[7]
    do_reshape = False
    if before_reshape_full_shape and after_reshape_slice_shape\
            and after_reshape_slice_shape != before_reshape_slice_shape:
        do_reshape = True

    allgather_net = None
    mp_weight = False
@@ -1494,17 +1506,22 @@ def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, i
    # while any dim is not equal to -1, means param is split and needs to be merged
    # pipeline parallel need to be supported here later
    if mp_weight:
        allgather_net = get_allgather_cell(opt_shard_group, bool(opt_shard_group))
        allgather_net = get_allgather_cell(opt_shard_group, bool(opt_shard_group), do_reshape,
                                           tuple(after_reshape_slice_shape))
        object.__setattr__(allgather_net, "keep_input_unchanged", True)
    elif opt_shard_group:
        allgather_net = get_allgather_cell(opt_shard_group, False)
        allgather_net = get_allgather_cell(opt_shard_group, False, do_reshape,
                                           tuple(after_reshape_slice_shape))
    elif opt_shard_group and context.get_auto_parallel_context("optimizer_weight_shard_aggregated_save"):
        allgather_net = get_allgather_cell(opt_shard_group, False)
        allgather_net = get_allgather_cell(opt_shard_group, False, do_reshape,
                                           tuple(after_reshape_slice_shape))
    net.parallel_parameter_merge_net_dict[param_name] = allgather_net
    if allgather_net:
        param_data = allgather_net(param_data)
    if mp_weight and integrated_save:
        param_data = _reshape_param_data(param_data, dev_mat, tensor_map)
    if do_reshape:
        param_data = _reshape_tensor(param_data, before_reshape_full_shape)
    return param_data

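Checkpoint merging must now also undo the implicit reshape: gather the slices, then restore the pre-reshape full shape recorded in the layout. In outline (hypothetical helper, with numpy standing in for the AllGather result):

```python
import numpy as np

def merge_param(slices, before_reshape_full_shape):
    """Concatenate per-device slices, then restore the original (pre-reshape) shape."""
    merged = np.concatenate(slices, axis=0)  # what the AllGather produces along dim 0
    return merged.reshape(before_reshape_full_shape)

parts = [np.ones((128, 512)) for _ in range(8)]
print(merge_param(parts, (1024, 512)).shape)  # (1024, 512)
```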
@@ -27,6 +27,11 @@ def setup_function():


def test_get_parameter_layout():
    """
    Feature: test get parameter layout
    Description: test get parameter layout.
    Expectation: compile success
    """
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, weight):
            super().__init__()
@@ -53,8 +58,10 @@ def test_get_parameter_layout():
    net.set_train()
    exe = me._cell_graph_executor
    exe.compile(net, x, phase='train')
    x_layout = ([8], [0, -1], [32, 32], 0, True, '')  # device_arrangement = [2, 4], tensor_map = [1, -1]
    weight_layout = ([2, 4], [0, -1], [16, 32], 0, True, '')  # device_arrangement = [2, 4], tensor_map = [0, -1]
    # device_arrangement = [2, 4], tensor_map = [1, -1]
    x_layout = ([8], [0, -1], [32, 32], 0, True, '')
    # device_arrangement = [2, 4], tensor_map = [0, -1]
    weight_layout = ([2, 4], [0, -1], [16, 32], 0, True, '')
    expect_dict = {'x': x_layout, 'w1': weight_layout}
    # to be resolved: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
    assert net.parameter_layout_dict["x"][0:6] == expect_dict["x"]
@@ -0,0 +1,282 @@
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.parallel.shard import Layout
from tests.ut.python.ops.test_math_ops import VirtualLoss
from parallel.utils.utils import ParallelValidator


def setup_function():
    context.set_auto_parallel_context(dataset_strategy="full_batch")


grad_all = C.GradOperation(get_all=True)


class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, y):
        predict = self.network(y)
        return self.loss(predict)


class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, y):
        return grad_all(self.network)(y)


def compile_net(net, input_x):
    net.set_auto_parallel()
    net.set_train()
    phase, _ = _cell_graph_executor.compile(net, input_x)
    return phase


class Net(nn.Cell):
    def __init__(self, weight, in_layout, out_layout=None):
        super().__init__()
        self.matmul1 = P.MatMul().shard(in_strategy=in_layout, out_strategy=out_layout)
        self.relu = P.ReLU()
        self.w = Parameter(weight, "w1")

    def construct(self, y):
        out1 = self.matmul1(y, self.w)
        out2 = self.relu(out1)
        out = out1 + out2
        return out


class Net1(nn.Cell):
    def __init__(self, weight, in_layout, out_layout=None):
        super().__init__()
        self.matmul1 = P.MatMul().shard(in_strategy=in_layout, out_strategy=out_layout)
        self.relu = P.ReLU()
        self.w = Parameter(weight, "w1")

    def construct(self, y):
        y = self.relu(y)
        out1 = self.matmul1(y, self.w)
        return out1


x = Tensor(np.ones([1024, 1024]), dtype=ms.float32)
w = Tensor(np.ones([1024, 1024]), dtype=ms.float32)


def test_layout_extend_base():
    """
    Feature: test layout extend
    Description: dev_num is 8.
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
    layout1 = (layout("dp", "sp"), layout("sp", "mp"))
    net = Net(w, layout1)
    phase = compile_net(net, x)
    validator = ParallelValidator(net, phase)
    assert validator.check_parameter_shape('w1', [512, 512])


def test_layout_extend_base_reduce_scatter():
    """
    Feature: test layout extend
    Description: dev_num is 8.
    Expectation: compile success, forward reduce_scatter
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
    layout1 = (layout("dp", "sp"), layout("sp", "mp"))
    out_layout = (layout(("dp", "sp"), "mp"),)
    net = Net(w, layout1, out_layout)
    phase = compile_net(net, x)
    validator = ParallelValidator(net, phase)
    assert validator.check_parameter_shape('w1', [512, 512])
    assert validator.check_node_inputs_has('ReduceScatter-0', ['MatMul'])


def test_layout_extend_batch_multi_shard():
    """
    Feature: test layout extend
    Description: dev_num is 16, batch dim multi shard.
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    layout = Layout((2, 2, 2, 2), ("dp", "sp", "vp", "mp"))
    layout1 = (layout(("dp", "mp"), "sp"), layout("sp", "vp"))
    net = Net(w, layout1)
    phase = compile_net(net, x)
    validator = ParallelValidator(net, phase)
    assert validator.check_parameter_shape('w1', [512, 512])


def test_layout_extend_batch_multi_shard_reduce_scatter():
    """
    Feature: test layout extend
    Description: dev_num is 16, batch dim multi shard.
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    layout = Layout((2, 2, 2, 2), ("dp", "sp", "vp", "mp"))
    layout1 = (layout(("dp", "mp"), "sp"), layout("sp", "vp"))
    out_layout = (layout(("dp", "mp", "sp"), "vp"),)
    net = Net(w, layout1, out_layout)
    phase = compile_net(net, x)
    validator = ParallelValidator(net, phase)
    assert validator.check_parameter_shape('w1', [512, 512])
    assert validator.check_node_inputs_has('ReduceScatter-0', ['MatMul'])


def test_layout_extend_reduce_axis_multi_shard():
    """
    Feature: test layout extend
    Description: dev_num is 16, reduce dim multi shard.
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    layout = Layout((2, 2, 2, 2), ("dp", "sp", "vp", "mp"))
    layout1 = (layout("dp", ("sp", "mp")), layout(("sp", "mp"), "vp"))
    net = Net(w, layout1)
    phase = compile_net(net, x)
    validator = ParallelValidator(net, phase)
    assert validator.check_parameter_shape('w1', [256, 512])


def test_layout_extend_reduce_axis_multi_shard_reduce_scatter():
    """
    Feature: test layout extend
    Description: dev_num is 16, reduce dim multi shard.
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    layout = Layout((2, 2, 2, 2), ("dp", "sp", "vp", "mp"))
    layout1 = (layout("dp", ("sp", "mp")), layout(("sp", "mp"), "vp"))
    out_layout = (layout(("dp", "sp", "mp"), "vp"),)
    net = GradWrap(NetWithLoss(Net(w, layout1, out_layout)))
    phase = compile_net(net, x)
    validator = ParallelValidator(net, phase)
    assert validator.check_parameter_shape('network.network.w', [256, 512])
    assert validator.check_node_inputs_has('ReduceScatter-0', ['MatMul'])


def test_layout_extend_reduce_axis_multi_shard_reduce_scatter_opt_shard():
    """
    Feature: test layout extend
    Description: dev_num is 16, reduce dim multi shard, enable optimizer parallel.
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0,
                                      enable_parallel_optimizer=True)
    layout = Layout((2, 2, 2, 2), ("dp", "sp", "vp", "mp"))
    layout1 = (layout("dp", ("sp", "mp")), layout(("sp", "mp"), "vp"))
    out_layout = (layout(("dp", "sp", "mp"), "vp"),)
    net = GradWrap(NetWithLoss(Net(w, layout1, out_layout)))
    phase = compile_net(net, x)
    validator = ParallelValidator(net, phase)
    context.reset_auto_parallel_context()
    assert validator.check_parameter_layout('network.network.w',
                                            ([2, 2, 2, 2], [2, 0, 1], [256, 512], 0, True, '2-16557109384257890687'))
    assert validator.check_node_inputs_has('ReduceScatter-0', ['MatMul'])


def test_layout_extend_reduce_axis_multi_shard_reduce_scatter_opt_shard_not_full():
    """
    Feature: test layout extend
    Description: dev_num is 32, reduce dim multi shard, enable optimizer parallel, opt_shard=2.
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0,
                                      enable_parallel_optimizer=True,
                                      parallel_optimizer_config={"optimizer_weight_shard_size": 2})
    layout = Layout((4, 2, 2, 2), ("dp", "sp", "vp", "mp"))
    layout1 = (layout("dp", ("sp", "mp")), layout(("sp", "mp"), "vp"))
    out_layout = (layout(("dp", "sp", "mp"), "vp"),)
    net = GradWrap(NetWithLoss(Net(w, layout1, out_layout)))
    phase = compile_net(net, x)
    validator = ParallelValidator(net, phase)
    context.reset_auto_parallel_context()
    assert validator.check_parameter_layout('network.network.w',
                                            ([4, 2, 2, 2], [2, 0, 1], [256, 512], 0, True, '2-16557109384257890687'))
    assert validator.check_node_inputs_has('ReduceScatter-0', ['MatMul'])


def test_layout_extend_reduce_axis_multi_shard_reduce_scatter_including_dev1():
    """
    Feature: test layout extend
    Description: dev_num is 8, reduce dim multi shard.
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    layout = Layout((2, 2, 2, 1), ("dp", "sp", "vp", "mp"))
    layout1 = (layout("dp", ("sp", "mp")), layout(("sp", "mp"), "vp"))
    out_layout = (layout(("dp", "sp", "mp"), "vp"),)
    net = GradWrap(NetWithLoss(Net(w, layout1, out_layout)))
    phase = compile_net(net, x)
    validator = ParallelValidator(net, phase)
    assert validator.check_parameter_shape('network.network.w', [512, 512])
    assert validator.check_node_inputs_has('ReduceScatter-0', ['MatMul'])


def test_layout_extend_reduce_axis_multi_shard_reduce_scatter_including_axis_none():
    """
    Feature: test layout extend
    Description: dev_num is 8, reduce dim multi shard.
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
    layout1 = (layout("dp", ("sp", "mp")), layout(("sp", "mp"), "None"))
    out_layout = (layout(("dp", "sp", "mp"), "None"),)
    net = GradWrap(NetWithLoss(Net(w, layout1, out_layout)))
    phase = compile_net(net, x)
    validator = ParallelValidator(net, phase)
    assert validator.check_parameter_shape('network.network.w', [256, 1024])
    assert validator.check_node_inputs_has('ReduceScatter-0', ['MatMul'])


def test_layout_extend_reduce_axis_multi_shard_reduce_scatter_including_reduce_axis_none():
    """
    Feature: test layout extend
    Description: dev_num is 8, reduce dim multi shard with None.
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
    layout1 = (layout("dp", ("sp", "None")), layout(("sp", "None"), "mp"))
    out_layout = (layout(("dp", "sp", "None"), "mp"),)
    net = GradWrap(NetWithLoss(Net(w, layout1, out_layout)))
    phase = compile_net(net, x)
    validator = ParallelValidator(net, phase)
    assert validator.check_parameter_shape('network.network.w', [512, 512])
    assert validator.check_node_inputs_has('ReduceScatter-0', ['MatMul'])


def test_layout_extend_reduce_axis_multi_shard_reduce_scatter_including_reduce_axis_none_and_not_full():
    """
    Feature: test layout extend
    Description: dev_num is 8, reduce dim multi shard with None, and not shard full.
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
    layout1 = (layout("dp", ("sp", "None")), layout(("sp", "None"), "None"))
    out_layout = (layout(("dp", "sp", "None"), "None"),)
    net = GradWrap(NetWithLoss(Net(w, layout1, out_layout)))
    phase = compile_net(net, x)
    validator = ParallelValidator(net, phase)
    assert validator.check_parameter_shape('network.network.w', [512, 1024])
    assert validator.check_node_inputs_has('ReduceScatter-0', ['MatMul'])
@ -0,0 +1,169 @@
|
|||
# Copyright 2024 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore import context
|
||||
from mindspore.common.api import _cell_graph_executor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.parallel.shard import Layout
|
||||
from tests.ut.python.ops.test_math_ops import VirtualLoss
|
||||
from parallel.utils.utils import ParallelValidator
|
||||
|
||||
|
||||
def setup_function():
|
||||
context.set_auto_parallel_context(dataset_strategy="full_batch")
|
||||
|
||||
grad_all = C.GradOperation(get_all=True)
|
||||
|
||||
class NetWithLoss(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(NetWithLoss, self).__init__()
|
||||
self.loss = VirtualLoss()
|
||||
self.network = network
|
||||
|
||||
def construct(self, y):
|
||||
predict = self.network(y)
|
||||
return self.loss(predict)
|
||||
|
||||
|
||||
class GradWrap(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(GradWrap, self).__init__()
|
||||
self.network = network
|
||||
|
||||
def construct(self, y):
|
||||
return grad_all(self.network)(y)
|
||||
|
||||
|
||||
def compile_net(net, input_x):
|
||||
net.set_auto_parallel()
|
||||
net.set_train()
|
||||
phase, _ = _cell_graph_executor.compile(net, input_x)
|
||||
return phase
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, weight, in_layout, out_layout=None):
|
||||
super().__init__()
|
||||
self.add = P.Add().shard(in_strategy=in_layout, out_strategy=out_layout)
|
||||
self.relu = P.ReLU()
|
||||
self.w = Parameter(weight, "w1")
|
||||
|
||||
def construct(self, y):
|
||||
out1 = self.add(y, self.w)
|
||||
out2 = self.relu(out1)
|
||||
out = out1 + out2
|
||||
return out
|
||||
|
||||
x = Tensor(np.ones([1024, 1024]), dtype=ms.float32)
|
||||
w = Tensor(np.ones([1024, 1024]), dtype=ms.float32)
|
||||
|
||||
input_1024 = Tensor(np.ones([1024]), dtype=ms.float32)
|
||||
input_1_1024 = Tensor(np.ones([1, 1024]), dtype=ms.float32)
|
||||
input_1024_1024 = Tensor(np.ones([1024, 1024]), dtype=ms.float32)
|
||||
|
||||
|
||||
def test_layout_extend_add_same_shape_same_shard():
    """
    Feature: test layout extend
    Description: dev_num is 8.
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
    layout1 = (layout(("dp", "sp"), "mp"), layout(("dp", "sp"), "mp"))
    first, second = input_1024_1024, x
    net = Net(second, layout1)
    phase = compile_net(net, first)
    validator = ParallelValidator(net, phase)
    assert validator.check_parameter_shape('w1', [256, 512])


def test_layout_extend_add_same_shape_wrong_shard():
    """
    Feature: test layout extend
    Description: dev_num is 8.
    Expectation: compile failed
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
    layout1 = (layout(("dp", "sp"), "mp"), layout(("dp", "mp"), "sp"))
    first, second = input_1024_1024, x
    net = Net(second, layout1)
    with pytest.raises(RuntimeError):
        compile_net(net, first)


def test_layout_extend_add_same_dim_broadcast():
    """
    Feature: test layout extend
    Description: dev_num is 8.
    Expectation: compile success, second input broadcast
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
    layout1 = (layout(("dp", "sp"), "mp"), layout("None", "mp"))
    first, second = input_1024_1024, input_1_1024
    net = Net(second, layout1)
    phase = compile_net(net, first)
    validator = ParallelValidator(net, phase)
    assert validator.check_parameter_shape('w1', [1, 512])


def test_layout_extend_add_different_dim_broadcast():
    """
    Feature: test layout extend
    Description: dev_num is 8.
    Expectation: compile success, second input broadcast
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
    layout1 = (layout(("dp", "sp"), "mp"), layout("mp",))
    first, second = input_1024_1024, input_1024
    net = Net(second, layout1)
    phase = compile_net(net, first)
    validator = ParallelValidator(net, phase)
    assert validator.check_parameter_shape('w1', [512])

def test_layout_extend_add_different_dim_broadcast_failed():
    """
    Feature: test layout extend
    Description: dev_num is 8.
    Expectation: compile failed
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
    layout1 = (layout(("dp", "sp"), "mp"), layout("None",))
    first, second = input_1024_1024, input_1024
    net = Net(second, layout1)
    with pytest.raises(RuntimeError):
        compile_net(net, first)


def test_layout_extend_add_same_shape_same_shard_outputlayout_not_allowed():
    """
    Feature: test layout extend
    Description: dev_num is 8.
    Expectation: compile failed
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
    layout1 = (layout(("dp", "sp"), "mp"), layout(("dp", "sp"), "mp"))
    out_layout = (layout(("dp", "sp"), "mp"),)
    first, second = input_1024_1024, x
    net = Net(second, layout1, out_layout)
    with pytest.raises(RuntimeError):
        compile_net(net, first)
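The expected parameter shapes in the tests above follow from simple slice arithmetic: each tensor dimension is divided by the product of the device-matrix axes grouped onto it, and "None" contributes a factor of 1. A minimal sketch of that rule (expected_slice_shape is a hypothetical helper for illustration, not a MindSpore API):

def expected_slice_shape(shape, dim_groups, device_matrix, alias_names):
    # map each alias to its device-matrix axis size; "None" means unsharded
    size = dict(zip(alias_names, device_matrix))
    size["None"] = 1
    out = []
    for dim, group in zip(shape, dim_groups):
        names = group if isinstance(group, tuple) else (group,)
        factor = 1
        for name in names:
            factor *= size[name]
        out.append(dim // factor)
    return out

# layout(("dp", "sp"), "mp") on device_matrix (2, 2, 2) -> w1 slice [256, 512]
assert expected_slice_shape([1024, 1024], [("dp", "sp"), "mp"],
                            (2, 2, 2), ("dp", "sp", "mp")) == [256, 512]
# the broadcast case layout("None", "mp") -> w1 slice [1, 512]
assert expected_slice_shape([1, 1024], ["None", "mp"],
                            (2, 2, 2), ("dp", "sp", "mp")) == [1, 512]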

@@ -0,0 +1,121 @@
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.parallel.shard import Layout
from tests.ut.python.ops.test_math_ops import VirtualLoss
from parallel.utils.utils import ParallelValidator


def setup_function():
    context.set_auto_parallel_context(dataset_strategy="full_batch")


grad_all = C.GradOperation(get_all=True)


class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, y):
        predict = self.network(y)
        return self.loss(predict)


class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, y):
        return grad_all(self.network)(y)


def compile_net(net, input_x):
    net.set_auto_parallel()
    net.set_train()
    phase, _ = _cell_graph_executor.compile(net, input_x)
    return phase


class Net(nn.Cell):
    def __init__(self, weight, in_layout, out_layout=None):
        super().__init__()
        self.bias_add = P.BiasAdd().shard(in_strategy=in_layout, out_strategy=out_layout)
        self.relu = P.ReLU()
        self.w = Parameter(weight, "w1")

    def construct(self, y):
        out1 = self.bias_add(y, self.w)
        out2 = self.relu(out1)
        out = out1 + out2
        return out


x_1_1024 = Tensor(np.ones([1, 1024]), dtype=ms.float32)
x_1024_1024 = Tensor(np.ones([1024, 1024]), dtype=ms.float32)
bias = Tensor(np.ones([1024]), dtype=ms.float32)


def test_layout_extend_bias_add_shard():
    """
    Feature: test layout extend
    Description: dev_num is 8.
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
    layout1 = (layout("mp", ("dp", "sp")), layout(("dp", "sp"),))
    first, second = x_1024_1024, bias
    net = Net(second, layout1)
    phase = compile_net(net, first)
    validator = ParallelValidator(net, phase)
    assert validator.check_parameter_shape('w1', [256])


def test_layout_extend_bias_add_channel_shard_different():
    """
    Feature: test layout extend
    Description: dev_num is 8.
    Expectation: compile failed, channel shard different
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
    layout1 = (layout(("dp", "sp"), "mp"), layout("sp",))
    first, second = x_1024_1024, bias
    net = Net(second, layout1)
    with pytest.raises(RuntimeError):
        compile_net(net, first)


def test_layout_extend_bias_add_self_define_outputlayout_not_allowed():
    """
    Feature: test layout extend
    Description: dev_num is 8.
    Expectation: compile failed
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
    layout1 = (layout(("dp", "sp"), "mp"), layout(("dp", "sp"), "mp"))
    out_layout = (layout(("dp", "sp"), "mp"),)
    first, second = x_1024_1024, bias
    net = Net(second, layout1, out_layout)
    with pytest.raises(RuntimeError):
        compile_net(net, first)
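The failing case above appears to hit a consistency rule: the bias is applied along the channel axis, so its single dimension must be sharded by the same device axes as the channel dimension of the input. A hedged sketch of such a check (channel_shard_consistent is illustrative, not a MindSpore API):

def channel_shard_consistent(input_dim_groups, bias_dim_groups):
    # compare the alias group of the input's channel (last) dimension with
    # the alias group of the bias's only dimension
    def as_tuple(group):
        return group if isinstance(group, tuple) else (group,)
    return as_tuple(input_dim_groups[-1]) == as_tuple(bias_dim_groups[0])

# passing case: layout("mp", ("dp", "sp")) with bias layout(("dp", "sp"),)
assert channel_shard_consistent(["mp", ("dp", "sp")], [("dp", "sp")])
# failing case: layout(("dp", "sp"), "mp") with bias layout("sp",)
assert not channel_shard_consistent([("dp", "sp"), "mp"], ["sp"])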
@@ -0,0 +1,118 @@
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.parallel.shard import Layout
from tests.ut.python.ops.test_math_ops import VirtualLoss
from parallel.utils.utils import ParallelValidator


def setup_function():
    context.set_auto_parallel_context(dataset_strategy="full_batch")


grad_all = C.GradOperation(get_all=True)


class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, y):
        predict = self.network(y)
        return self.loss(predict)


class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, y):
        return grad_all(self.network)(y)


def compile_net(net, input_x):
    net.set_auto_parallel()
    net.set_train()
    phase, _ = _cell_graph_executor.compile(net, input_x)
    return phase


class Net(nn.Cell):
    def __init__(self, in_layout, out_layout=None):
        super().__init__()
        self.gelu = P.GeLU().shard(in_strategy=in_layout, out_strategy=out_layout)

    def construct(self, y):
        out = self.gelu(y)
        return out


x = Tensor(np.ones([1024, 1024]), dtype=ms.float32)


def test_layout_extend_base():
    """
    Feature: test layout extend
    Description: dev_num is 4.
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
    layout = Layout((2, 2), ("dp", "mp"))
    layout1 = (layout("dp", "mp"),)
    net = Net(layout1)
    phase = compile_net(net, x)
    validator = ParallelValidator(net, phase)
    assert validator.check_node_inputs("GeLU-0", ["StridedSlice-1"])


def test_layout_extend_batch_multi_shard():
    """
    Feature: test layout extend
    Description: dev_num is 8, batch dim multi shard.
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
    layout1 = (layout(("dp", "mp"), "sp"),)
    net = Net(layout1)
    phase = compile_net(net, x)
    validator = ParallelValidator(net, phase)
    assert validator.check_node_inputs("GeLU-0", ["Reshape-1"])


def test_layout_extend_reduce_axis_multi_shard():
    """
    Feature: test layout extend
    Description: dev_num is 8, reduce dim multi shard.
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
    layout1 = (layout("dp", ("mp", "sp")),)
    net = Net(layout1)
    phase = compile_net(net, x)
    validator = ParallelValidator(net, phase)
    assert validator.check_node_inputs("GeLU-0", ["Reshape-1"])
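The asserts above distinguish two redistribution patterns: one device axis per tensor dimension yields a plain StridedSlice feeding the GeLU, while grouping several device axes onto one dimension splits it by their product, which shows up as a Reshape instead. A toy numpy view of the resulting slice shapes (illustration only, not MindSpore code):

import numpy as np

full = np.ones([1024, 1024], dtype=np.float32)
# layout("dp", "mp") on device_matrix (2, 2): each dim halved
single = full[:1024 // 2, :1024 // 2]
# layout(("dp", "mp"), "sp") on (2, 2, 2): dim 0 split by dp*mp = 4
grouped = full[:1024 // 4, :1024 // 2]
assert single.shape == (512, 512)
assert grouped.shape == (256, 512)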
@@ -0,0 +1,152 @@
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.parallel.shard import Layout
from tests.ut.python.ops.test_math_ops import VirtualLoss
from parallel.utils.utils import ParallelValidator


def setup_function():
    context.set_auto_parallel_context(dataset_strategy="full_batch")


grad_all = C.GradOperation(get_all=True)


class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, y):
        predict = self.network(y)
        return self.loss(predict)


class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, y):
        return grad_all(self.network)(y)


def compile_net(net, input_x):
    net.set_auto_parallel()
    net.set_train()
    phase, _ = _cell_graph_executor.compile(net, input_x)
    return phase


class Net(nn.Cell):
    def __init__(self, x_gamma, x_beta, in_layout, out_layout=None, begin_norm_axis=1):
        super().__init__()
        self.relu = P.ReLU()
        self.layernorm = P.LayerNorm(begin_norm_axis).shard(in_strategy=in_layout,
                                                            out_strategy=out_layout)
        self.gamma = Parameter(x_gamma, "gamma")
        self.beta = Parameter(x_beta, "beta")

    def construct(self, y):
        out1, _, _ = self.layernorm(y, self.gamma, self.beta)
        out2 = self.relu(out1)
        out = out1 + out2
        return out


x = Tensor(np.ones([128, 16, 32]), dtype=ms.float32)
gamma = Tensor(np.ones([16, 32]), dtype=ms.float32)
beta = Tensor(np.ones([16, 32]), dtype=ms.float32)


def test_layout_layernorm_base():
    """
    Feature: test layout extend
    Description: dev_num is 8.
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
    layout1 = (layout("dp", "sp", "None"), layout("sp", "None"), layout("sp", "None"))
    net = Net(gamma, beta, layout1, begin_norm_axis=2)
    phase = compile_net(net, x)
    validator = ParallelValidator(net, phase)
    assert validator.check_parameter_shape('gamma', [8, 32])


def test_layout_layernorm_multi_shard():
    """
    Feature: test layout extend for multi shard
    Description: dev_num is 16.
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    layout = Layout((2, 2, 2, 2), ("dp", "sp", "vp", "mp"))
    layout1 = (layout(("dp", "mp"), "sp", "None"), layout("sp", "None"), layout("sp", "None"))
    net = Net(gamma, beta, layout1, begin_norm_axis=2)
    phase = compile_net(net, x)
    validator = ParallelValidator(net, phase)
    assert validator.check_parameter_shape('gamma', [8, 32])


def test_layout_layernorm_multi_shard1():
    """
    Feature: test layout extend for multi shard
    Description: dev_num is 16.
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    layout = Layout((2, 2, 2, 2), ("dp", "sp", "vp", "mp"))
    layout1 = (layout("dp", ("sp", "mp"), "None"), layout(("sp", "mp"), "None"), layout(("sp", "mp"), "None"))
    net = Net(gamma, beta, layout1, begin_norm_axis=2)
    phase = compile_net(net, x)
    validator = ParallelValidator(net, phase)
    assert validator.check_parameter_shape('gamma', [4, 32])


def test_layout_layernorm_out_check():
    """
    Feature: test layout extend for output layout check
    Description: dev_num is 16.
    Expectation: compile failed, throw exception
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    layout = Layout((2, 2, 2, 2), ("dp", "sp", "vp", "mp"))
    layout1 = (layout(("dp", "mp"), "sp", "None"), layout("sp", "None"), layout("sp", "None"))
    out_layout = (layout(("dp", "mp"), "sp", "None"), layout(("dp", "mp"), "sp", "None"),
                  layout(("dp", "mp"), "sp", "None"))
    net = Net(gamma, beta, layout1, out_layout=out_layout, begin_norm_axis=2)
    with pytest.raises(RuntimeError):
        compile_net(net, x)


def test_layout_layernorm_multi_shard_with_grad():
    """
    Feature: test layout extend with grad
    Description: dev_num is 16.
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    layout = Layout((2, 2, 2, 2), ("dp", "sp", "vp", "mp"))
    layout1 = (layout(("dp", "mp"), "sp", "None"), layout("sp", "None"), layout("sp", "None"))
    net = GradWrap(NetWithLoss(Net(gamma, beta, layout1, begin_norm_axis=2)))
    phase = compile_net(net, x)
    validator = ParallelValidator(net, phase)
    assert validator.check_parameter_shape('network.network.gamma', [8, 32])
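The gamma shapes asserted above follow the same slice arithmetic: gamma is [16, 32], its first dimension is divided by the product of the device axes grouped onto it, and "None" leaves a dimension whole. A quick check (plain Python, for illustration):

def slice_dim(dim, axis_sizes):
    # divide a tensor dimension by each device axis sharded onto it
    for size in axis_sizes:
        dim //= size
    return dim

assert slice_dim(16, [2]) == 8      # layout("sp", "None") on (2, 2, 2): gamma -> [8, 32]
assert slice_dim(16, [2, 2]) == 4   # layout(("sp", "mp"), "None") on (2, 2, 2, 2): gamma -> [4, 32]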
@@ -24,13 +24,13 @@ def test_load_tensor():
     dev_mat = [2, 3]
     tensor_map = [1, -1]
     hccl.rank_id = 5
-    tensor_slice = _load_tensor(tensor, dev_mat, tensor_map)
+    tensor_slice = _load_tensor(tensor, dev_mat, tensor_map, [2, 3])
     expected_tensor = Tensor([[4, 5, 6]])
     if expected_tensor.__str__() != tensor_slice.__str__():
         raise AssertionError

     hccl.rank_id = 2
-    tensor_slice = _load_tensor(tensor, dev_mat, tensor_map)
+    tensor_slice = _load_tensor(tensor, dev_mat, tensor_map, [2, 3])
     expected_tensor = Tensor([[1, 2, 3]])
     if expected_tensor.__str__() != tensor_slice.__str__():
         raise AssertionError
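The hunk above only shows that _load_tensor gained a fourth argument (here [2, 3]); its meaning is not visible in this diff. The expected slices themselves can be re-derived with plain numpy (a hedged sketch, not the real _load_tensor): dev_mat [2, 3] arranges six ranks in a 2x3 grid, and tensor_map [1, -1] shards dim 0 along the device axis of size 2 while leaving dim 1 replicated.

import numpy as np

tensor = np.array([[1, 2, 3], [4, 5, 6]])

def row_slice(rank, dev_mat=(2, 3)):
    # device axis 1 (counting from the right) has size dev_mat[0] = 2;
    # a rank's coordinate on that axis is rank // dev_mat[1]
    coord = rank // dev_mat[1]
    return tensor[coord:coord + 1, :]

assert (row_slice(5) == np.array([[4, 5, 6]])).all()   # rank 5 -> second row
assert (row_slice(2) == np.array([[1, 2, 3]])).all()   # rank 2 -> first row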
@@ -82,6 +82,11 @@ def compile_net_no_bias(net, x, y):


 def test_no_grad():
+    """
+    Feature: test no grad
+    Description: dev_num is 8, test no grad.
+    Expectation: compile success
+    """
     class Net(nn.Cell):
         def __init__(self, strategy1, strategy2):
             super().__init__()

@@ -107,6 +112,11 @@ def test_no_grad():


 def test_grad_sens_parameter_type():
+    """
+    Feature: test grad sens parameter
+    Description: dev_num is 8, test grad sens parameter.
+    Expectation: compile success
+    """
     class Net(nn.Cell):
         def __init__(self, strategy1, strategy2):
             super().__init__()

@@ -142,6 +152,11 @@ def test_grad_sens_parameter_type():


 def test_grad_sens_tensor_type():
+    """
+    Feature: test grad sens tensor type
+    Description: dev_num is 8, test grad sens tensor type.
+    Expectation: compile success
+    """
     class Net(nn.Cell):
         def __init__(self, strategy1, strategy2):
             super().__init__()

@@ -167,6 +182,11 @@ def test_grad_sens_tensor_type():


 def test_grad_sens_scalar_broadcast():
+    """
+    Feature: test grad sens scalar broadcast
+    Description: dev_num is 8, test grad sens scalar broadcast.
+    Expectation: compile success
+    """
     class Net(nn.Cell):
         def __init__(self, strategy0, strategy1):
             super().__init__()
@@ -73,7 +73,7 @@ class ParallelValidator:
         if param_name not in self._parameter_layout_dict.keys():
             return False
-        return self._parameter_layout_dict[param_name][0:6] == layout
+        return self._parameter_layout_dict[param_name][:6] == layout

     def check_parameter_shape(self, param_name: str, shape: [tuple, list]) -> bool:
         """Verify parameter shape"""