!16727 handle duplicate code

From: @yangzhenzhang
Reviewed-by: @stsuteng, @jjfeing
Signed-off-by: @stsuteng
mindspore-ci-bot 2021-05-26 10:58:54 +08:00 committed by Gitee
commit 1e8868cecc
66 changed files with 101 additions and 1599 deletions

View File

@@ -287,23 +287,6 @@ Status ActivationBase::InferTensorMap() {
return SUCCESS;
}
Status ActivationBase::InferTensorInfo() {
TensorLayout input_tensor_layout, output_tensor_layout;
if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) ||
(output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]))) {
MS_LOG(ERROR) << name_ << ": init tensor layout failed";
return FAILED;
}
TensorInfo input_tensor_info(input_tensor_layout);
TensorInfo output_tensor_info(output_tensor_layout);
inputs_tensor_info_.push_back(input_tensor_info);
outputs_tensor_info_.push_back(output_tensor_info);
return SUCCESS;
}
Status DropoutInfo::GetAttrs() {
auto iter0 = attrs_.find(SEED0);
if (iter0 != attrs_.end()) {
@@ -328,29 +311,29 @@ Status DropoutInfo::GetAttrs() {
return SUCCESS;
}
Status DropoutInfo::InferTensorInfo() {
// infer tensor shape
Shape input_shape = inputs_shape_.at(0);
Status DropoutInfo::InferTensorMap() {
Shape tensor_map_in;
size_t size = inputs_shape_.at(0).size();
// such as 4: tensor_map_index [3,2,1,0]
for (size_t i = 0; i < size; ++i) {
tensor_map_in.push_back((int64_t)(size - i - 1));
}
// infer slice shape
Shapes inputs_slice_shape, outputs_slice_shape;
Strategys inputs_strategy = strategy_->GetInputDim();
// dropout has two outputs
Strategys outputs_strategy = {inputs_strategy.at(0), inputs_strategy.at(0)};
if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
inputs_tensor_map_.push_back(tensor_map_in);
outputs_tensor_map_.push_back(tensor_map_in);
outputs_tensor_map_.push_back(tensor_map_in); // the dropout has two outputs
return SUCCESS;
}
Status DropoutInfo::InferAsLossDivisor() {
if (outputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": The size of outputs tensor map is empty";
return FAILED;
}
Shape input_slice_shape = inputs_slice_shape.at(0);
TensorLayout input_tensor_layout;
if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) {
return FAILED;
}
TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
inputs_tensor_info_.push_back(input_tensor_info);
// the two outputs of dropout all have the same tensor_info as input
outputs_tensor_info_.push_back(input_tensor_info);
outputs_tensor_info_.push_back(input_tensor_info);
as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]);
MS_LOG(INFO) << name_ << " : The dev matrix shape is " << ShapeToString(dev_matrix_shape_)
<< ", the output[0]'s tensor map is " << ShapeToString(outputs_tensor_map_[0])
<< ", as_loss_divisor_ is " << as_loss_divisor_;
return SUCCESS;
}
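The new DropoutInfo::InferTensorMap above builds the identity tensor map, [3, 2, 1, 0] for a rank-4 input, and pushes it once for the input and twice for the two outputs. A minimal self-contained sketch of that construction (plain std::vector<int64_t> stands in for the repo's Shape alias, an assumption made here for illustration):

#include <cstddef>
#include <cstdint>
#include <vector>

// Build the identity tensor map for a tensor of the given rank:
// dimension i is mapped to device-matrix axis (rank - i - 1).
std::vector<int64_t> IdentityTensorMap(size_t rank) {
  std::vector<int64_t> tensor_map;
  for (size_t i = 0; i < rank; ++i) {
    tensor_map.push_back(static_cast<int64_t>(rank - i - 1));
  }
  return tensor_map;
}

// IdentityTensorMap(4) yields {3, 2, 1, 0}; DropoutInfo pushes this map once
// for the input and twice for the outputs, since dropout has two outputs.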
@@ -514,58 +497,6 @@ Status ExpandDimsInfo::InferTensorStrategy() {
return SUCCESS;
}
Status ExpandDimsInfo::InferTensorInfo() {
if (inputs_shape_.empty() || outputs_shape_.empty()) {
MS_LOG(ERROR) << name_ << ": The shape of inputs or outputs is empty";
return FAILED;
}
if (inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": The tensor map of inputs or outputs is empty";
return FAILED;
}
Shape input_shape = inputs_shape_[0];
Shape output_shape = outputs_shape_[0];
// infer slice shape
if (InferTensorStrategy() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer tensor strategy failed";
return FAILED;
}
Shapes inputs_slice_shape, outputs_slice_shape;
if (InferSliceShape(inputs_strategy_, outputs_strategy_, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer slice shape failed";
return FAILED;
}
if (inputs_slice_shape.empty() || outputs_slice_shape.empty()) {
MS_LOG(ERROR) << name_ << ": The slice shape of inputs or outputs is empty";
return FAILED;
}
Shape input_slice_shape = inputs_slice_shape[0];
Shape output_slice_shape = outputs_slice_shape[0];
TensorLayout input_tensor_layout, output_tensor_layout;
if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init tensor layout for input failed";
return FAILED;
}
if (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init tensor layout for output failed";
return FAILED;
}
TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape);
inputs_tensor_info_.push_back(input_tensor_info);
outputs_tensor_info_.push_back(output_tensor_info);
return SUCCESS;
}
Status ExpandDimsInfo::InferMirrorOps() {
mirror_ops_.clear();
@@ -676,65 +607,6 @@ Status SqueezeInfo::InferTensorMap() {
return SUCCESS;
}
Status SqueezeInfo::InferTensorInfo() {
if (inputs_shape_.empty() || outputs_shape_.empty()) {
MS_LOG(ERROR) << name_ << ": The shape of inputs or outputs is empty";
return FAILED;
}
if (inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": The tensor map of inputs or outputs is empty";
return FAILED;
}
Shape input_shape = inputs_shape_[0];
Shape output_shape = outputs_shape_[0];
// infer slice shape
Shapes inputs_slice_shape, outputs_slice_shape;
Strategys inputs_strategy = strategy_->GetInputDim();
Dimensions output_strategy;
std::vector<int64_t> axis = GetValue<const std::vector<int64_t>>(axis_);
for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
auto iter = std::find(axis.begin(), axis.end(), SizeToLong(i));
if (iter == axis.end()) {
output_strategy.push_back(inputs_strategy[0].at(i));
}
}
Strategys outputs_strategy = {output_strategy};
if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer slice shape failed";
return FAILED;
}
if (inputs_slice_shape.empty() || outputs_slice_shape.empty()) {
MS_LOG(ERROR) << name_ << ": The slice shape of inputs or outputs is empty";
return FAILED;
}
Shape input_slice_shape = inputs_slice_shape[0];
Shape output_slice_shape = outputs_slice_shape[0];
// infer tensor layout
TensorLayout input_tensor_layout, output_tensor_layout;
if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init tensor layout for input failed";
return FAILED;
}
if (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init tensor layout for output failed";
return FAILED;
}
TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape);
inputs_tensor_info_.push_back(input_tensor_info);
outputs_tensor_info_.push_back(output_tensor_info);
return SUCCESS;
}
Status SqueezeInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Init failed.";

View File

@@ -43,7 +43,6 @@ class ActivationBase : public OperatorInfo {
Status InferMirrorOps() override;
Status InferForwardCommunication() override;
Status InferTensorMap() override;
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
};
@@ -214,7 +213,6 @@ class ExpandDimsInfo : public ActivationOther {
protected:
Status GetAttrs() override;
Status InferTensorMap() override;
Status InferTensorInfo() override;
Status InferMirrorOps() override;
Status InferTensorStrategy();
@@ -236,7 +234,6 @@ class SqueezeInfo : public ActivationOther {
Status GetAttrs() override;
Status InferReplaceOps(const StrategyPtr &strategy);
Status InferTensorMap() override;
Status InferTensorInfo() override;
Status Init(const StrategyPtr &strategy) override;
private:
@@ -270,8 +267,9 @@ class DropoutInfo : public ActivationOther {
protected:
Status GetAttrs() override;
Status InferTensorInfo() override;
Status InferTensorMap() override;
Status InferReplaceOps(const StrategyPtr &strategy);
Status InferAsLossDivisor() override;
private:
int64_t seed0_ = 0;

View File

@@ -176,120 +176,6 @@ Status ArithmeticBase::InferTensorMap() {
return SUCCESS;
}
Status ArithmeticBase::InferMirrorOps() {
mirror_ops_.clear();
Shape input_a_tensor_map = inputs_tensor_map_.at(0);
Shape input_b_tensor_map = inputs_tensor_map_.at(1);
std::vector<Group> input_a_group, input_b_group;
if (CreateGroupByTensorMap(input_a_tensor_map, &input_a_group) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Create group for input a failed.";
return FAILED;
}
if (CreateGroupByTensorMap(input_b_tensor_map, &input_b_group) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Create group for input b failed.";
return FAILED;
}
OperatorVector op_for_input_a, op_for_input_b;
if (input_a_group.empty() && input_b_group.empty()) {
MS_LOG(INFO) << name_ << " : The mirror group is empty.";
return SUCCESS;
}
if (!input_a_group.empty()) {
op_for_input_a = CreateMirrorOps(input_a_group[0].name(), input_a_group[0].GetDevNum());
MS_LOG(INFO) << name_ << " : Create the mirror ops for input a success, group is " << input_a_group[0].name();
}
if (!input_b_group.empty()) {
op_for_input_b = CreateMirrorOps(input_b_group[0].name(), input_b_group[0].GetDevNum());
MS_LOG(INFO) << name_ << " : Create the mirror ops for input b success, group is " << input_b_group[0].name();
}
mirror_ops_.push_back(op_for_input_a);
mirror_ops_.push_back(op_for_input_b);
return SUCCESS;
}
Status ArithmeticBase::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout,
const Shape &dev_matrix_array) {
if ((inputs_layout == nullptr) || (outputs_layout == nullptr)) {
MS_LOG(ERROR) << name_ << " : The layout is null.";
return FAILED;
}
TensorMap input_a_tensor_map_array = inputs_tensor_map_.at(0);
TensorMap input_b_tensor_map_array = inputs_tensor_map_.at(1);
TensorMap out_tensor_map_array = outputs_tensor_map_.at(0);
Shape input_a_shape_array = inputs_shape_.at(0);
Shape input_b_shape_array = inputs_shape_.at(1);
Shape out_shape_array = outputs_shape_.at(0);
TensorLayout input_a_tensor_layout, input_b_tensor_layout, out_tensor_layout;
if (input_a_tensor_layout.InitFromVector(dev_matrix_array, input_a_tensor_map_array, input_a_shape_array) !=
SUCCESS) {
MS_LOG(ERROR) << name_ << " : Create tensor layout for input a failed.";
return FAILED;
}
if (input_b_tensor_layout.InitFromVector(dev_matrix_array, input_b_tensor_map_array, input_b_shape_array) !=
SUCCESS) {
MS_LOG(ERROR) << name_ << " : Create tensor layout for input b failed.";
return FAILED;
}
if (out_tensor_layout.InitFromVector(dev_matrix_array, out_tensor_map_array, out_shape_array) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Create tensor layout for output failed.";
return FAILED;
}
inputs_layout->push_back(input_a_tensor_layout);
inputs_layout->push_back(input_b_tensor_layout);
outputs_layout->push_back(out_tensor_layout);
return SUCCESS;
}
Status ArithmeticBase::InferTensorInfo() {
// infer tensor shape
Shape input_a_shape = inputs_shape_.at(0);
Shape input_b_shape = inputs_shape_.at(1);
Shape output_shape = outputs_shape_.at(0);
// infer slice shape
Shapes inputs_slice_shape, outputs_slice_shape;
Strategys expend_strategy = ExpendStrategy(strategy_);
Dimensions sub_a_expend_strategy = expend_strategy.at(0);
Dimensions sub_b_expend_strategy = expend_strategy.at(1);
Strategys inputs_strategy = strategy_->GetInputDim();
Shape dev_shape;
for (size_t i = 0; i < sub_a_expend_strategy.size(); ++i) {
if (sub_a_expend_strategy[i] != sub_b_expend_strategy[i]) {
dev_shape.push_back(sub_a_expend_strategy[i] * sub_b_expend_strategy[i]);
} else {
dev_shape.push_back(sub_a_expend_strategy[i]);
}
}
Strategys outputs_strategy = {dev_shape};
if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
return FAILED;
}
Shape input_a_slice_shape = inputs_slice_shape.at(0);
Shape input_b_slice_shape = inputs_slice_shape.at(1);
Shape output_slice_shape = outputs_slice_shape.at(0);
// infer tensor layout
TensorLayouts inputs_layout, outputs_layout;
if (InferTensorLayout(&inputs_layout, &outputs_layout, dev_matrix_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Infer tensor layout failed.";
return FAILED;
}
TensorInfo input_a_tensor_info(inputs_layout.at(0), input_a_shape, input_a_slice_shape);
TensorInfo input_b_tensor_info(inputs_layout.at(1), input_b_shape, input_b_slice_shape);
TensorInfo out_tensor_info(outputs_layout.at(0), output_shape, output_slice_shape);
inputs_tensor_info_.push_back(input_a_tensor_info); // inputs_a
inputs_tensor_info_.push_back(input_b_tensor_info); // inputs_b
outputs_tensor_info_.push_back(out_tensor_info); // output
return SUCCESS;
}
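The deleted block above derived the output strategy for broadcasting: where the two expanded input strategies disagree on a dimension (one side is 1 under broadcasting), the output dimension takes their product; otherwise it passes through. A self-contained sketch of that rule (names are illustrative, not the repo's API):

#include <cstddef>
#include <cstdint>
#include <vector>

// Combine two broadcast-expanded per-dimension strategies into the output
// strategy: equal entries pass through, differing entries multiply together
// (under broadcasting, one of the two differing entries is 1).
std::vector<int64_t> CombineBroadcastStrategies(const std::vector<int64_t> &a,
                                                const std::vector<int64_t> &b) {
  std::vector<int64_t> out;
  for (size_t i = 0; i < a.size(); ++i) {
    out.push_back(a[i] != b[i] ? a[i] * b[i] : a[i]);
  }
  return out;
}

// Example: a = {2, 1, 4} and b = {2, 8, 4} give out = {2, 8, 4}.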
Status ArithmeticBase::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
Status ArithmeticBase::GenerateStrategies(int64_t stage_id) {

View File

@@ -44,12 +44,9 @@ class ArithmeticBase : public OperatorInfo {
protected:
Status GetAttrs() override { return SUCCESS; }
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferMirrorOps() override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout, const Shape &dev_matrix_array);
Shapes InferExpendShape();
};

View File

@@ -107,28 +107,6 @@ Strategys BatchParallelInfo::GetOutputsStrategy() {
return outputs_strategy;
}
Status BatchParallelInfo::InferTensorInfo() {
for (size_t i = 0; i < strategy_->GetInputNumber(); i++) {
MS_LOG(INFO) << name_ << " : The input size is " << strategy_->GetInputNumber();
TensorLayout tensor_layout_in;
if (tensor_layout_in.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(i), inputs_shape_.at(i)) != SUCCESS) {
return FAILED;
}
TensorInfo tensor_info_in(tensor_layout_in);
inputs_tensor_info_.push_back(tensor_info_in);
}
for (size_t i = 0; i < outputs_shape_.size(); i++) {
TensorLayout tensor_layout_out;
if (tensor_layout_out.InitFromVector(dev_matrix_shape_, outputs_tensor_map_.at(i), outputs_shape_.at(i)) !=
SUCCESS) {
return FAILED;
}
TensorInfo tensor_info_out(tensor_layout_out);
outputs_tensor_info_.push_back(tensor_info_out);
}
return SUCCESS;
}
Status BatchParallelInfo::GetAttrs() { return SUCCESS; }
Status BatchParallelInfo::Init(const StrategyPtr &strategy) {

View File

@@ -45,7 +45,6 @@ class BatchParallelInfo : public OperatorInfo {
protected:
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferForwardCommunication() override;
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status GetAttrs() override;

View File

@@ -73,109 +73,6 @@ Status BiasAddInfo::InferTensorMap() {
return SUCCESS;
}
Status BiasAddInfo::InferMirrorOps() {
mirror_ops_.clear();
Shape input_a_tensor_map = inputs_tensor_map_.at(0);
Shape input_b_tensor_map = inputs_tensor_map_.at(1);
std::vector<Group> input_a_group, input_b_group;
if (CreateGroupByTensorMap(input_a_tensor_map, &input_a_group) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Create group for input a failed.";
return FAILED;
}
if (CreateGroupByTensorMap(input_b_tensor_map, &input_b_group) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Create group for input b failed.";
return FAILED;
}
OperatorVector op_for_input_a, op_for_input_b;
if (input_a_group.empty() && input_b_group.empty()) {
MS_LOG(INFO) << name_ << " : The mirror group is empty.";
return SUCCESS;
}
if (!input_a_group.empty()) {
op_for_input_a = CreateMirrorOps(input_a_group[0].name(), input_a_group[0].GetDevNum());
MS_LOG(INFO) << name_ << " : Create the mirror ops for input a success, group is " << input_a_group[0].name();
}
if (!input_b_group.empty()) {
op_for_input_b = CreateMirrorOps(input_b_group[0].name(), input_b_group[0].GetDevNum());
MS_LOG(INFO) << name_ << " : Create the mirror ops for input b success, group is " << input_b_group[0].name();
}
mirror_ops_.push_back(op_for_input_a);
mirror_ops_.push_back(op_for_input_b);
return SUCCESS;
}
Status BiasAddInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout,
const Shape &dev_matrix_array) {
if ((inputs_layout == nullptr) || (outputs_layout == nullptr)) {
MS_LOG(ERROR) << name_ << " : The layout is null.";
return FAILED;
}
TensorMap input_a_tensor_map_array = inputs_tensor_map_.at(0);
TensorMap input_b_tensor_map_array = inputs_tensor_map_.at(1);
TensorMap out_tensor_map_array = outputs_tensor_map_.at(0);
Shape input_a_shape_array = inputs_shape_.at(0);
Shape input_b_shape_array = inputs_shape_.at(1);
Shape out_shape_array = outputs_shape_.at(0);
TensorLayout input_a_tensor_layout, input_b_tensor_layout, out_tensor_layout;
if (input_a_tensor_layout.InitFromVector(dev_matrix_array, input_a_tensor_map_array, input_a_shape_array) !=
SUCCESS) {
MS_LOG(ERROR) << name_ << " : Create tensor layout for input a failed.";
return FAILED;
}
if (input_b_tensor_layout.InitFromVector(dev_matrix_array, input_b_tensor_map_array, input_b_shape_array) !=
SUCCESS) {
MS_LOG(ERROR) << name_ << " : Create tensor layout for input b failed.";
return FAILED;
}
if (out_tensor_layout.InitFromVector(dev_matrix_array, out_tensor_map_array, out_shape_array) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Create tensor layout for output failed.";
return FAILED;
}
inputs_layout->push_back(input_a_tensor_layout);
inputs_layout->push_back(input_b_tensor_layout);
outputs_layout->push_back(out_tensor_layout);
return SUCCESS;
}
Status BiasAddInfo::InferTensorInfo() {
// infer tensor shape
Shape input_a_shape = inputs_shape_.at(0);
Shape input_b_shape = inputs_shape_.at(1);
Shape output_shape = outputs_shape_.at(0);
// infer slice shape
Shapes inputs_slice_shape, outputs_slice_shape;
Strategys inputs_strategy = strategy_->GetInputDim();
Strategys outputs_strategy = {inputs_strategy.at(0)};
if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
return FAILED;
}
Shape input_a_slice_shape = inputs_slice_shape.at(0);
Shape input_b_slice_shape = inputs_slice_shape.at(1);
Shape output_slice_shape = outputs_slice_shape.at(0);
// infer tensor layout
TensorLayouts inputs_layout, outputs_layout;
if (InferTensorLayout(&inputs_layout, &outputs_layout, dev_matrix_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Infer tensor layout failed.";
return FAILED;
}
TensorInfo input_a_tensor_info(inputs_layout.at(0), input_a_shape, input_a_slice_shape);
TensorInfo input_b_tensor_info(inputs_layout.at(1), input_b_shape, input_b_slice_shape);
TensorInfo out_tensor_info(outputs_layout.at(0), output_shape, output_slice_shape);
inputs_tensor_info_.push_back(input_a_tensor_info); // inputs_a
inputs_tensor_info_.push_back(input_b_tensor_info); // inputs_b
outputs_tensor_info_.push_back(out_tensor_info); // output
return SUCCESS;
}
Status BiasAddInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
Status BiasAddInfo::GenerateStrategies(int64_t stage_id) {

View File

@@ -46,12 +46,9 @@ class BiasAddInfo : public OperatorInfo {
protected:
Status GetAttrs() override { return SUCCESS; }
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferMirrorOps() override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout, const Shape &dev_matrix_array);
};
} // namespace parallel
} // namespace mindspore

View File

@@ -112,56 +112,6 @@ Status BroadcastToInfo::InferTensorMap() {
return SUCCESS;
}
Status BroadcastToInfo::InferMirrorOps() {
mirror_ops_.clear();
if (inputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": The inputs tensor map is empty";
return FAILED;
}
Shape input_tensor_map = inputs_tensor_map_[0];
std::vector<Group> group;
if (CreateGroupByTensorMap(input_tensor_map, &group) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Create group for input failed.";
return FAILED;
}
if (group.empty()) {
MS_LOG(INFO) << name_ << ": The mirror group is empty.";
return SUCCESS;
}
OperatorVector input_op;
input_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
mirror_ops_.push_back(input_op);
return SUCCESS;
}
Status BroadcastToInfo::InferTensorInfo() {
if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": Invalid args";
return FAILED;
}
TensorLayout input_layout, output_layout;
// infer tensor layout
if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
return FAILED;
}
TensorInfo input_tensor_info(input_layout);
inputs_tensor_info_.push_back(input_tensor_info);
if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
return FAILED;
}
TensorInfo output_tensor_info(output_layout);
outputs_tensor_info_.push_back(output_tensor_info);
return SUCCESS;
}
Status BroadcastToInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
Status BroadcastToInfo::GenerateStrategies(int64_t stage_id) {

View File

@@ -48,9 +48,7 @@ class BroadcastToInfo : public OperatorInfo {
protected:
Status GetAttrs() override;
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferMirrorOps() override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status ComputeReplaceGraph(const CNodePtr &cnode);

View File

@@ -137,60 +137,6 @@ Status ConcatInfo::InferTensorMap() {
return SUCCESS;
}
Status ConcatInfo::InferMirrorOps() {
mirror_ops_.clear();
if (inputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": The inputs tensor map is empty";
return FAILED;
}
Shape input_tensor_map = inputs_tensor_map_[0];
std::vector<Group> group;
if (CreateGroupByTensorMap(input_tensor_map, &group) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Create group for input failed.";
return FAILED;
}
if (group.empty()) {
MS_LOG(INFO) << name_ << ": The mirror group is empty.";
return SUCCESS;
}
OperatorVector input_op;
input_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
mirror_ops_.push_back(input_op);
}
return SUCCESS;
}
Status ConcatInfo::InferTensorInfo() {
if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": Invalid args";
return FAILED;
}
TensorLayout input_layout, output_layout;
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
// infer tensor layout
if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[i], inputs_shape_[i]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
return FAILED;
}
TensorInfo input_tensor_info(input_layout);
inputs_tensor_info_.push_back(input_tensor_info);
}
if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
return FAILED;
}
TensorInfo output_tensor_info(output_layout);
outputs_tensor_info_.push_back(output_tensor_info);
return SUCCESS;
}
void ConcatInfo::ReComputeBatchSplitFlagList() {
for (size_t i = 0; i < inputs_shape_.size(); i++) {
split_flag_list_[i] = true;

View File

@@ -45,9 +45,7 @@ class ConcatInfo : public OperatorInfo {
protected:
Status GetAttrs() override;
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferMirrorOps() override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;

View File

@@ -88,38 +88,6 @@ Status DropoutDoMaskInfo::InferTensorMap() {
return SUCCESS;
}
Status DropoutDoMaskInfo::InferTensorInfo() {
if (inputs_shape_.size() != 3) {
MS_LOG(ERROR) << name_ << ": Invalid inputs shape size " << inputs_shape_.size();
return FAILED;
}
if (strategy_ == nullptr) {
MS_LOG(ERROR) << name_ << ": The strategy is null";
return FAILED;
}
Shape input_0_shape = inputs_shape_[0];
if (inputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": The inputs tensor map is empty";
return FAILED;
}
TensorLayout input_0_tensor_layout;
if (input_0_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_0_shape) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init tensor layout failed";
return FAILED;
}
TensorInfo input_0_tensor_info(input_0_tensor_layout);
// input_1 does not need tensor info
inputs_tensor_info_.push_back(input_0_tensor_info); // input_0
outputs_tensor_info_.push_back(input_0_tensor_info); // output
return SUCCESS;
}
Status DropoutDoMaskInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
return SetCostUnderStrategyBase(strategy);
}

View File

@@ -49,7 +49,6 @@ class DropoutDoMaskInfo : public OperatorInfo {
Status InferForwardCommunication() override { return SUCCESS; }
Status InferTensorMap() override;
Status GetAttrs() override { return SUCCESS; }
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
};

View File

@@ -111,32 +111,6 @@ Status GatherNdInfo::InferTensorMap() {
return SUCCESS;
}
Status GatherNdInfo::InferTensorInfo() {
if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": Invalid args";
return FAILED;
}
TensorLayout input_layout, output_layout;
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
// infer tensor layout
if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[i], inputs_shape_[i]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
return FAILED;
}
TensorInfo input_tensor_info(input_layout);
inputs_tensor_info_.push_back(input_tensor_info);
}
if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
return FAILED;
}
TensorInfo output_tensor_info(output_layout);
outputs_tensor_info_.push_back(output_tensor_info);
return SUCCESS;
}
void GatherNdInfo::ReComputeBatchSplitFlagList() {
split_flag_list_[0] = false;
split_flag_list_[1] = true;

View File

@@ -46,7 +46,6 @@ class GatherNdInfo : public OperatorInfo {
Status GetAttrs() override { return SUCCESS; }
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
};

View File

@@ -45,7 +45,7 @@ class GetNextInfo : public OperatorInfo {
Status GetAttrs() override;
Status InferTensorMap() override;
Status InferTensorLayout(TensorLayouts *outputs_layout);
Status InferTensorInfo() override;
Status InferTensorInfo() override; // GetNext() has no inputs, so it needs to override InferTensorInfo().
Status InferDevMatrixShape() override;
Status InferMirrorOps() override { return SUCCESS; }
Status InferForwardCommunication() override { return SUCCESS; }

View File

@@ -62,28 +62,6 @@ Status L2NormalizeInfo::GetAttrs() {
return SUCCESS;
}
Status L2NormalizeInfo::InferMirrorOps() {
mirror_ops_.clear();
Shape input_tensor_map = inputs_tensor_map_.at(0);
std::vector<Group> input_group;
if (CreateGroupByTensorMap(input_tensor_map, &input_group) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Create group failed.";
return FAILED;
}
OperatorVector op_for_weight;
if (input_group.empty()) {
MS_LOG(INFO) << name_ << " : The mirror ops is empty.";
return SUCCESS;
} else {
op_for_weight = CreateMirrorOps(input_group[0].name(), input_group[0].GetDevNum());
mirror_ops_.push_back(op_for_weight);
MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group is " << input_group[0].name();
}
return SUCCESS;
}
Status L2NormalizeInfo::GenerateStrategies(int64_t stage_id) {
if (GetAttrs() != SUCCESS) {
MS_LOG(ERROR) << name_ << " : GetAttrs failed.";

View File

@@ -39,7 +39,6 @@ class L2NormalizeInfo : public Activation {
protected:
Status GetAttrs() override;
Status InferMirrorOps() override;
Status CheckStrategy(const StrategyPtr &strategy) override;
private:

View File

@@ -22,6 +22,13 @@
namespace mindspore {
namespace parallel {
// the layernorm has three outputs
// if the shape of input is [A, B, C, D], the shape of first output is [A, B, C, D]
// if the begin-norm-axis is 0, the shape of second output is: [1, 1, 1, 1]
// if the begin-norm-axis is 1, the shape of second output is: [A, 1, 1, 1]
// if the begin-norm-axis is 2, the shape of second output is: [A, B, 1, 1]
// if the begin-norm-axis is 3, the shape of second output is: [A, B, C, 1]
// the shape of third output is the same as the shape of second output
Status LayerNormInfo::GetAttrs() {
auto iter = attrs_.find(BEGIN_NORM_AXIS);
if (iter == attrs_.end()) {
@@ -113,7 +120,7 @@ Status LayerNormInfo::InferDevMatrixShape() {
return SUCCESS;
}
Status LayerNormInfo::CreateTensorMap(size_t input_index) {
Status LayerNormInfo::CreateInputTensorMap(size_t input_index) {
if (inputs_shape_.size() <= input_index) {
MS_LOG(ERROR) << name_ << ": Invalid index" << input_index;
return FAILED;
@@ -124,74 +131,27 @@ Status LayerNormInfo::CreateTensorMap(size_t input_index) {
tensor_map.push_back(SizeToLong(shape.size() - i - 1));
}
inputs_tensor_map_.push_back(tensor_map);
outputs_tensor_map_.push_back(tensor_map);
return SUCCESS;
}
Status LayerNormInfo::InferTensorMap() {
if ((CreateTensorMap(LAYER_NORM_INPUT_INDEX) != SUCCESS) || (CreateTensorMap(LAYER_NORM_GAMMA_INDEX) != SUCCESS) ||
(CreateTensorMap(LAYER_NORM_BETA_INDEX) != SUCCESS)) {
MS_LOG(ERROR) << name_ << ": Create tensor map failed";
return FAILED;
}
return SUCCESS;
}
Status LayerNormInfo::CreateMirrorOp(size_t input_index) {
if (inputs_tensor_map_.size() <= input_index) {
MS_LOG(ERROR) << name_ << ": Invalid index " << input_index;
return FAILED;
}
Shape tensor_map = inputs_tensor_map_[input_index];
std::vector<Group> group;
if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Create group for input " << input_index << " failed";
return FAILED;
}
OperatorVector mirror_op;
if (!group.empty()) {
mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
MS_LOG(INFO) << name_ << " : Create the mirror ops for input " << input_index << " success, group is "
<< group[0].name();
}
mirror_ops_.push_back(mirror_op);
return SUCCESS;
}
Status LayerNormInfo::InferMirrorOps() {
if ((CreateMirrorOp(LAYER_NORM_INPUT_INDEX) != SUCCESS) || (CreateMirrorOp(LAYER_NORM_GAMMA_INDEX) != SUCCESS) ||
(CreateMirrorOp(LAYER_NORM_BETA_INDEX) != SUCCESS)) {
MS_LOG(ERROR) << name_ << ": Create mirror op failed";
return FAILED;
}
return SUCCESS;
}
Status LayerNormInfo::CreateTensorInfo(size_t input_index) {
if ((inputs_shape_.size() <= input_index) || (inputs_tensor_map_.size() <= input_index)) {
MS_LOG(ERROR) << name_ << ": Invalid input index" << input_index;
return FAILED;
}
Shape tensor_map = inputs_tensor_map_[input_index];
Shape shape = inputs_shape_[input_index];
TensorLayout tensor_layout;
if (tensor_layout.InitFromVector(dev_matrix_shape_, tensor_map, shape) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init tensor layout for input " << input_index << " failed";
if ((CreateInputTensorMap(LAYER_NORM_INPUT_INDEX) != SUCCESS) ||
(CreateInputTensorMap(LAYER_NORM_GAMMA_INDEX) != SUCCESS) ||
(CreateInputTensorMap(LAYER_NORM_BETA_INDEX) != SUCCESS)) {
MS_LOG(ERROR) << name_ << ": Create input tensor map failed";
return FAILED;
}
TensorInfo tensor_info(tensor_layout);
inputs_tensor_info_.push_back(tensor_info);
outputs_tensor_info_.push_back(tensor_info);
return SUCCESS;
}
Status LayerNormInfo::InferTensorInfo() {
if ((CreateTensorInfo(LAYER_NORM_INPUT_INDEX) != SUCCESS) || (CreateTensorInfo(LAYER_NORM_GAMMA_INDEX) != SUCCESS) ||
(CreateTensorInfo(LAYER_NORM_BETA_INDEX) != SUCCESS)) {
MS_LOG(ERROR) << name_ << ": Create tensor info failed";
return FAILED;
Shape first_output_tensor_map = inputs_tensor_map_[0];
Shape second_output_tensor_map = first_output_tensor_map;
for (size_t i = begin_norm_axis_; i < second_output_tensor_map.size(); ++i) {
second_output_tensor_map[i] = MAP_NONE;
}
Shape third_output_tensor_map = second_output_tensor_map;
outputs_tensor_map_.push_back(first_output_tensor_map);
outputs_tensor_map_.push_back(second_output_tensor_map);
outputs_tensor_map_.push_back(third_output_tensor_map);
return SUCCESS;
}
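The new output maps can be checked with a small worked example: for input map (3, 2, 1, 0) and begin_norm_axis = 2, the three outputs get (3, 2, 1, 0), (3, 2, -1, -1), and (3, 2, -1, -1), matching the shape comment at the top of the file. A self-contained sketch (MAP_NONE is assumed to be -1, and plain vectors stand in for the Shape alias):

#include <cstddef>
#include <cstdint>
#include <vector>

constexpr int64_t kMapNone = -1;  // stand-in for MAP_NONE

// Derive LayerNorm's three output tensor maps from the input's map:
// output 0 keeps the input map; outputs 1 and 2 (mean and variance)
// drop the device mapping for dimensions from begin_norm_axis onward.
std::vector<std::vector<int64_t>> LayerNormOutputMaps(const std::vector<int64_t> &input_map,
                                                      size_t begin_norm_axis) {
  std::vector<int64_t> reduced_map = input_map;
  for (size_t i = begin_norm_axis; i < reduced_map.size(); ++i) {
    reduced_map[i] = kMapNone;
  }
  return {input_map, reduced_map, reduced_map};
}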

View File

@@ -52,15 +52,11 @@ class LayerNormInfo : public OperatorInfo {
protected:
Status GetAttrs() override;
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferMirrorOps() override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status InferAsLossDivisor() override;
Status CreateTensorMap(size_t input_index);
Status CreateTensorInfo(size_t input_index);
Status CreateMirrorOp(size_t input_index);
Status CreateInputTensorMap(size_t input_index);
Status GenerateGammaAndBetaStrategies(const std::vector<StrategyPtr> &sp_vector);
Status InitShapes();

View File

@@ -92,41 +92,6 @@ Status SoftmaxCrossEntropyWithLogitsInfo::InferTensorMap() {
return SUCCESS;
}
Status SoftmaxCrossEntropyWithLogitsInfo::InferTensorInfo() {
// infer tensor shape
Shape input_shape = inputs_shape_.at(0);
Shape first_output_shape = outputs_shape_.at(0);
// infer slice shape
Shapes inputs_slice_shape, outputs_slice_shape;
Strategys inputs_strategy = strategy_->GetInputDim();
Strategys outputs_strategy = {{inputs_strategy[0][0]}, inputs_strategy.at(0)};
if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
return FAILED;
}
Shape input_slice_shape = inputs_slice_shape.at(0);
Shape first_output_slice_shape = outputs_slice_shape.at(0);
TensorMap input_tensor_map = inputs_tensor_map_.at(0);
TensorMap first_output_tensor_map = outputs_tensor_map_.at(0);
TensorLayout input_tensor_layout, first_output_tensor_layout;
if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, input_tensor_map, input_shape) != SUCCESS) ||
(first_output_tensor_layout.InitFromVector(dev_matrix_shape_, first_output_tensor_map, first_output_shape) !=
SUCCESS)) {
return FAILED;
}
TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
TensorInfo first_output_tensor_info(first_output_tensor_layout, first_output_shape, first_output_slice_shape);
inputs_tensor_info_.push_back(input_tensor_info); // input
inputs_tensor_info_.push_back(input_tensor_info); // label
outputs_tensor_info_.push_back(first_output_tensor_info); // output-0
outputs_tensor_info_.push_back(input_tensor_info); // output-1
return SUCCESS;
}
// There are two outputs for SoftmaxCrossEntropyWithLogits, and outputs[1] is used for grad, so the function is overloaded.
Status SoftmaxCrossEntropyWithLogitsInfo::InferAsLossDivisor() {
if (outputs_tensor_map_.size() != 2) {

View File

@@ -50,7 +50,6 @@ class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo {
Status GetAttrs() override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferTensorMap() override;
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
// There are two outputs for SoftmaxCrossEntropyWithLogits, and outputs[1] is used for grad, so
// InferAsLossDivisor is overloaded.

View File

@@ -50,7 +50,7 @@ class MatMulBase : public OperatorInfo {
protected:
Status InferForwardCommunication() override;
Status InferTensorInfo() override;
Status InferTensorInfo() override; // the forward_reduce_scatter mode needs to override this function
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout);

View File

@@ -112,26 +112,6 @@ Status OneHotInfo::InferTensorMap() {
// (0, (8,2), (), ()): 16 devices, two machines, model parallel among devices in the same machine, data parallel
// between machines, dev_matrix=(2,8), map_in=(1), map_out=(0,1); (0, (4,2), (), ()): 16 devices,
// dev_matrix=(2,4,2), map_in=(1), map_out=(0,1)
Status OneHotInfo::InferTensorInfo() {
// infer tensor shape
Shape input_shape = inputs_shape_.at(0);
Shape output_shape = outputs_shape_.at(0);
TensorLayout input_tensor_layout, output_tensor_layout;
if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) ||
(output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS)) {
return FAILED;
}
TensorInfo input_tensor_info(input_tensor_layout);
TensorInfo output_tensor_info(output_tensor_layout);
inputs_tensor_info_.push_back(input_tensor_info);
outputs_tensor_info_.push_back(output_tensor_info);
return SUCCESS;
}
Status OneHotInfo::ExtractInputInfo() {
CheckGlobalDeviceManager();
rank_ = g_device_manager->rank_index_in_stage();

View File

@@ -48,7 +48,6 @@ class OneHotInfo : public OperatorInfo {
Status GetAttrs() override;
Status InferMirrorOps() override { return SUCCESS; }
Status InferForwardCommunication() override { return SUCCESS; }
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status ExtractInputInfo();

View File

@@ -199,6 +199,35 @@ Status OperatorInfo::InferMirrorOps() {
return SUCCESS;
}
Status OperatorInfo::InferTensorInfo() {
if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": Invalid args";
return FAILED;
}
for (size_t i = 0; i < inputs_tensor_map_.size(); ++i) {
TensorLayout input_layout;
if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[i], inputs_shape_[i]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed, the index is " << i;
return FAILED;
}
TensorInfo input_tensor_info(input_layout);
inputs_tensor_info_.push_back(input_tensor_info);
}
for (size_t i = 0; i < outputs_tensor_map_.size(); ++i) {
TensorLayout output_layout;
if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[i], outputs_shape_[i]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed, the index is " << i;
return FAILED;
}
TensorInfo output_tensor_info(output_layout);
outputs_tensor_info_.push_back(output_tensor_info);
}
return SUCCESS;
}
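This base-class version replaces the near-identical per-operator implementations deleted throughout this commit: it simply turns every entry of the inferred tensor maps into a TensorInfo. A toy sketch of the resulting pattern (simplified names, not the real OperatorInfo API):

#include <cstdint>
#include <iostream>
#include <vector>

// The base class converts tensor maps into tensor infos, so a derived
// operator only has to describe its maps.
struct OpBase {
  virtual ~OpBase() = default;
  virtual std::vector<std::vector<int64_t>> InferTensorMap() = 0;
  void InferTensorInfo() {
    for (const auto &map : InferTensorMap()) {
      // the real code builds a TensorLayout/TensorInfo per map here
      std::cout << "layout with " << map.size() << " dims\n";
    }
  }
};

struct IdentityOp : OpBase {
  std::vector<std::vector<int64_t>> InferTensorMap() override {
    return {{3, 2, 1, 0}};  // identity map for a rank-4 tensor
  }
};

int main() {
  IdentityOp op;
  op.InferTensorInfo();  // no per-operator InferTensorInfo needed
  return 0;
}

Operators whose tensor infos do not follow directly from the maps, such as GetNext (no inputs) and MatMul in forward_reduce_scatter mode, keep their own overrides, as the header comments above note.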
Status OperatorInfo::InferRepeatedCalcInfo() {
int64_t g_dev_list_size = stage_device_size_;
int64_t dev_matrix_size =
@@ -278,11 +307,9 @@ Operator CreateVirtualDivOp(int64_t div_num) {
return op;
}
// use for forward all reduce
Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group) {
OperatorName operator_name = ALL_REDUCE;
ValuePtr attr0_value = MakeValue(reduce_op); // ReduceOP.SUM
ValuePtr attr1_value = MakeValue(group); // group
static OperatorArgs CreateReduceCommunicationOpArgs(const std::string &reduce_op, const std::string &group) {
ValuePtr attr0_value = MakeValue(reduce_op);
ValuePtr attr1_value = MakeValue(group);
Attr attr0 = std::make_pair(OP, attr0_value);
Attr attr1 = std::make_pair(GROUP, attr1_value);
OperatorAttrs operator_attrs;
@@ -290,7 +317,13 @@ Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group) {
operator_attrs.push_back(attr1);
OperatorParams operator_param;
OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
return std::make_pair(operator_attrs, operator_param);
}
// use for forward all reduce
Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group) {
OperatorName operator_name = ALL_REDUCE;
OperatorArgs operator_arg = CreateReduceCommunicationOpArgs(reduce_op, group);
Operator op = std::make_pair(operator_name, operator_arg);
MS_LOG(INFO) << "Create all reduce op success, the reduce_op is " << reduce_op << ", the group is " << group;
@@ -299,16 +332,7 @@ Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group) {
Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group) {
OperatorName operator_name = REDUCE_SCATTER;
ValuePtr attr0_value = MakeValue(reduce_op); // ReduceOP.SUM
ValuePtr attr1_value = MakeValue(group); // group
Attr attr0 = std::make_pair(OP, attr0_value);
Attr attr1 = std::make_pair(GROUP, attr1_value);
OperatorAttrs operator_attrs;
operator_attrs.push_back(attr0);
operator_attrs.push_back(attr1);
OperatorParams operator_param;
OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
OperatorArgs operator_arg = CreateReduceCommunicationOpArgs(reduce_op, group);
Operator op = std::make_pair(operator_name, operator_arg);
MS_LOG(INFO) << "Create reduce scatter op success, the reduce_op is " << reduce_op << ", the group is " << group;

View File

@@ -190,9 +190,9 @@ class OperatorInfo {
virtual Status InferTensorMap() = 0;
virtual Status InferForwardCommunication() = 0;
virtual Status GetAttrs() = 0;
virtual Status InferTensorInfo() = 0;
virtual Status InferDevMatrixShape() = 0;
virtual Status InferMirrorOps();
virtual Status InferTensorInfo();
Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape);
void SetRepeatedCalcDevMatrix();
void ResetTensorMapIfRepeatedCalc();

View File

@@ -119,60 +119,6 @@ Status StackInfo::InferTensorMap() {
return SUCCESS;
}
Status StackInfo::InferMirrorOps() {
mirror_ops_.clear();
if (inputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": The inputs tensor map is empty";
return FAILED;
}
Shape input_tensor_map = inputs_tensor_map_[0];
std::vector<Group> group;
if (CreateGroupByTensorMap(input_tensor_map, &group) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Create group for input failed.";
return FAILED;
}
if (group.empty()) {
MS_LOG(INFO) << name_ << ": The mirror group is empty.";
return SUCCESS;
}
OperatorVector input_op;
input_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
mirror_ops_.push_back(input_op);
}
return SUCCESS;
}
Status StackInfo::InferTensorInfo() {
if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": Invalid args";
return FAILED;
}
TensorLayout input_layout, output_layout;
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
// infer tensor layout
if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[i], inputs_shape_[i]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
return FAILED;
}
TensorInfo input_tensor_info(input_layout);
inputs_tensor_info_.push_back(input_tensor_info);
}
if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
return FAILED;
}
TensorInfo output_tensor_info(output_layout);
outputs_tensor_info_.push_back(output_tensor_info);
return SUCCESS;
}
void StackInfo::ReComputeBatchSplitFlagList() {
for (size_t i = 0; i < inputs_shape_.size(); i++) {
split_flag_list_[i] = true;

View File

@@ -45,9 +45,7 @@ class StackInfo : public OperatorInfo {
protected:
Status GetAttrs() override;
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferMirrorOps() override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;

View File

@@ -61,8 +61,6 @@ Status PReLUInfo::InferDevMatrixShape() {
return SUCCESS;
}
Status PReLUInfo::InferForwardCommunication() { return SUCCESS; }
/*
* the output tensor map is the same as the input tensor map
*/
@@ -85,64 +83,6 @@ Status PReLUInfo::InferTensorMap() {
return SUCCESS;
}
Dimensions PReLUInfo::GetOutputStrategy() {
Dimensions output_strategy = input_strategy_;
return output_strategy;
}
Status PReLUInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) {
if (inputs_layout == nullptr || outputs_layout == nullptr) {
MS_LOG(ERROR) << name_ << ": InferTensorLayout: the layout is null.";
return FAILED;
}
TensorLayout input_layout, param_layout, output_layout;
if ((input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) ||
(param_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[1], inputs_shape_[1]) != SUCCESS) ||
(output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS)) {
return FAILED;
}
inputs_layout->push_back(input_layout);
inputs_layout->push_back(param_layout);
outputs_layout->push_back(output_layout);
return SUCCESS;
}
Status PReLUInfo::InferTensorInfo() {
// infer tensor shape
Shape input_shape = inputs_shape_.at(0);
Shape param_shape = inputs_shape_.at(1);
Shape output_shape = outputs_shape_.at(0);
// infer slice shape
Shapes inputs_slice_shape, outputs_slice_shape;
Dimensions output_strategy = GetOutputStrategy();
Strategys inputs_strategy = strategy_->GetInputDim();
Strategys outputs_strategy = {output_strategy};
if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
return FAILED;
}
Shape input_slice_shape = inputs_slice_shape.at(0);
Shape param_slice_shape = inputs_slice_shape.at(1);
Shape output_slice_shape = outputs_slice_shape.at(0);
// infer tensor layout
TensorLayouts inputs_layout, outputs_layout;
if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) {
return FAILED;
}
TensorLayout input_layout = inputs_layout.at(0);
TensorLayout param_layout = inputs_layout.at(1);
TensorLayout output_layout = outputs_layout.at(0);
TensorInfo input_tensor_info(input_layout, input_shape, input_slice_shape);
TensorInfo param_tensor_info(param_layout, param_shape, param_slice_shape);
TensorInfo output_tensor_info(output_layout, output_shape, output_slice_shape);
inputs_tensor_info_.push_back(input_tensor_info);
inputs_tensor_info_.push_back(param_tensor_info);
outputs_tensor_info_.push_back(output_tensor_info);
return SUCCESS;
}
Status PReLUInfo::GetAttrs() {
if ((inputs_shape_.size() != PRELU_INPUTS_SIZE) || (outputs_shape_.size() != PRELU_OUTPUTS_SIZE)) {
MS_LOG(ERROR) << name_ << ": Inputs shape size " << inputs_shape_.size() << " or outputs shape size "

View File

@@ -45,13 +45,10 @@ class PReLUInfo : public OperatorInfo {
protected:
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferForwardCommunication() override;
Status InferTensorInfo() override;
Status InferForwardCommunication() { return SUCCESS; }
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout);
Status GetAttrs() override;
Dimensions GetOutputStrategy();
private:
Dimensions input_strategy_;

View File

@@ -83,37 +83,6 @@ Status RangeInfo::InferTensorMap() {
return SUCCESS;
}
Status RangeInfo::InferTensorInfo() {
if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": Invalid args";
return FAILED;
}
TensorLayout input_layout, output_layout;
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
// infer tensor layout
if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[i], inputs_shape_[i]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
return FAILED;
}
TensorInfo input_tensor_info(input_layout);
inputs_tensor_info_.push_back(input_tensor_info);
}
if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
return FAILED;
}
TensorInfo output_tensor_info(output_layout);
outputs_tensor_info_.push_back(output_tensor_info);
for (auto &tensor_info : inputs_tensor_info_) {
MS_LOG(INFO) << name_ << ": The input layout: " << tensor_info.tensor_layout().ToString();
}
MS_LOG(INFO) << name_ << ": The output layout: " << outputs_tensor_info_[0].tensor_layout().ToString();
return SUCCESS;
}
Status RangeInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed";

View File

@@ -52,7 +52,6 @@ class RangeInfo : public OperatorInfo {
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferMirrorOps() override;
Status InferForwardCommunication() override;
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status GetAttrs() override;

View File

@@ -81,30 +81,6 @@ Status ReLUV2Info::InferDevMatrixShape() {
return SUCCESS;
}
Status ReLUV2Info::InferMirrorOps() {
mirror_ops_.clear();
Shape tensor_map = inputs_tensor_map_[0];
std::vector<Group> group;
if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Create group failed.";
return FAILED;
}
OperatorVector mirror_op;
if (group.empty()) {
MS_LOG(INFO) << name_ << " : The mirror ops is empty.";
return SUCCESS;
} else {
mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
mirror_ops_.push_back(mirror_op);
std::string group_name = group[0].name();
MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name;
}
return SUCCESS;
}
Status ReLUV2Info::InferForwardCommunication() {
// do nothing
return SUCCESS;
@@ -129,37 +105,6 @@ Status ReLUV2Info::InferTensorMap() {
return SUCCESS;
}
Status ReLUV2Info::InferTensorInfo() {
if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": Invalid args";
return FAILED;
}
TensorLayout input_layout, output_layout, mask_layout;
// infer tensor layout
if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
return FAILED;
}
TensorInfo input_tensor_info(input_layout);
inputs_tensor_info_.push_back(input_tensor_info);
if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
return FAILED;
}
if (mask_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[1], outputs_shape_[1]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
return FAILED;
}
TensorInfo output_tensor_info(output_layout);
TensorInfo mask_tensor_info(mask_layout);
// output and mask
outputs_tensor_info_.push_back(output_tensor_info);
outputs_tensor_info_.push_back(mask_tensor_info);
return SUCCESS;
}
Status ReLUV2Info::InferAsLossDivisor() {
if (!ParallelContext::GetInstance()->loss_repeated_mean()) {
as_loss_divisor_ = 1;

View File

@@ -30,7 +30,7 @@
namespace mindspore {
namespace parallel {
/*
* The second dimension is not splitable, as mask is caculated along it.
* The second dimension is not splitable, as mask is calculated along it.
* The input and output have the same tensormap (3, 2, 1, 0), mask's tensormap is (3, 2, 1, 0, -1)
*/
class ReLUV2Info : public OperatorInfo {
@@ -46,10 +46,8 @@ class ReLUV2Info : public OperatorInfo {
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
protected:
Status InferMirrorOps() override;
Status InferForwardCommunication() override;
Status InferTensorMap() override;
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status CheckStrategy(const StrategyPtr &strategy) override;
Status GetAttrs() override;

View File

@@ -130,32 +130,6 @@ Status ScatterUpdateInfo::InferTensorMap() {
return SUCCESS;
}
Status ScatterUpdateInfo::InferTensorInfo() {
if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": Invalid args";
return FAILED;
}
TensorLayout input_layout, output_layout;
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
// infer tensor layout
if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[i], inputs_shape_[i]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
return FAILED;
}
TensorInfo input_tensor_info(input_layout);
inputs_tensor_info_.push_back(input_tensor_info);
}
if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
return FAILED;
}
TensorInfo output_tensor_info(output_layout);
outputs_tensor_info_.push_back(output_tensor_info);
return SUCCESS;
}
void ScatterUpdateInfo::ReComputeBatchSplitFlagList() {
for (size_t i = 0; i < inputs_shape_.size(); i++) {
split_flag_list_[i] = false; // the first dimension can not be split

View File

@@ -47,7 +47,6 @@ class ScatterUpdateInfo : public OperatorInfo {
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferMirrorOps() override { return SUCCESS; } // the scatter_update only use in eval/predict
Status InferForwardCommunication() override { return SUCCESS; }
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
};

View File

@@ -91,32 +91,6 @@ Status SelectInfo::InferTensorMap() {
return SUCCESS;
}
Status SelectInfo::InferTensorInfo() {
if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": Invalid args";
return FAILED;
}
TensorLayout input_layout, output_layout;
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
// infer tensor layout
if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[i], inputs_shape_[i]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
return FAILED;
}
TensorInfo input_tensor_info(input_layout);
inputs_tensor_info_.push_back(input_tensor_info);
}
if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
return FAILED;
}
TensorInfo output_tensor_info(output_layout);
outputs_tensor_info_.push_back(output_tensor_info);
return SUCCESS;
}
void SelectInfo::ReComputeBatchSplitFlagList() {
for (size_t i = 0; i < inputs_shape_.size(); i++) {
split_flag_list_[i] = true;

View File

@@ -46,7 +46,6 @@ class SelectInfo : public OperatorInfo {
Status GetAttrs() override { return SUCCESS; }
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
};

View File

@@ -148,31 +148,6 @@ Status SliceInfo::InferMirrorOps() {
return SUCCESS;
}
Status SliceInfo::InferTensorInfo() {
if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": Invalid args";
return FAILED;
}
// infer tensor layout
TensorLayout input_layout, output_layout;
if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
return FAILED;
}
if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
return FAILED;
}
TensorInfo input_tensor_info(input_layout);
TensorInfo output_tensor_info(output_layout);
inputs_tensor_info_.push_back(input_tensor_info);
outputs_tensor_info_.push_back(output_tensor_info);
return SUCCESS;
}
// Note: if the batch dimension is not fully fetched, the batch strategy may not work.
std::shared_ptr<Strategys> SliceInfo::GenerateBatchStrategies() {
split_flag_list_ = {true};

View File

@@ -49,7 +49,6 @@ class SliceInfo : public OperatorInfo {
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferMirrorOps() override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;

View File

@@ -133,61 +133,6 @@ Status SplitInfo::InferTensorMap() {
return SUCCESS;
}
Status SplitInfo::InferMirrorOps() {
mirror_ops_.clear();
if (inputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": The inputs tensor map is empty";
return FAILED;
}
Shape input_tensor_map = inputs_tensor_map_[0];
std::vector<Group> group;
if (CreateGroupByTensorMap(input_tensor_map, &group) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Create group for input failed.";
return FAILED;
}
OperatorVector mirror_op;
if (group.empty()) {
MS_LOG(INFO) << name_ << ": The mirror group is empty.";
return SUCCESS;
} else {
mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
mirror_ops_.push_back(mirror_op);
std::string group_name = group[0].name();
MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name;
}
return SUCCESS;
}
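This InferMirrorOps is another duplicate the commit folds away: the same create-group-then-mirror sequence recurs in UniformCandidateSamplerInfo and UniqueInfo below. A sketch of the generic form a shared base-class version could take — assuming it iterates every input's tensor map, not the committed implementation:

Status OperatorInfo::InferMirrorOps() {
  mirror_ops_.clear();
  for (size_t i = 0; i < inputs_tensor_map_.size(); ++i) {
    std::vector<Group> group;
    if (CreateGroupByTensorMap(inputs_tensor_map_[i], &group) != SUCCESS) {
      MS_LOG(ERROR) << name_ << ": Create group for input " << i << " failed.";
      return FAILED;
    }
    OperatorVector mirror_op;
    if (!group.empty()) {
      mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
      MS_LOG(INFO) << name_ << ": Create the mirror ops success, the group name is " << group[0].name();
    }
    // Push even when empty so the mirror op positions stay aligned with the input positions.
    mirror_ops_.push_back(mirror_op);
  }
  return SUCCESS;
}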
Status SplitInfo::InferTensorInfo() {
if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": Invalid args";
return FAILED;
}
TensorLayout input_layout, output_layout;
// infer tensor layout
if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
return FAILED;
}
TensorInfo input_tensor_info(input_layout);
inputs_tensor_info_.push_back(input_tensor_info);
if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
return FAILED;
}
for (size_t i = 0; i < outputs_shape_.size(); ++i) {
TensorInfo output_tensor_info(output_layout);
outputs_tensor_info_.push_back(output_tensor_info);
}
return SUCCESS;
}
Status SplitInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
Status SplitInfo::GenerateStrategies(int64_t stage_id) {

View File

@ -43,9 +43,7 @@ class SplitInfo : public OperatorInfo {
protected:
Status GetAttrs() override;
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferMirrorOps() override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status InferAsLossDivisor() override;

View File

@ -201,30 +201,6 @@ Status StridedSliceInfo::InferMirrorOps() {
return SUCCESS;
}
Status StridedSliceInfo::InferTensorInfo() {
if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": Invalid args";
return FAILED;
}
// infer tensor layout
TensorLayout input_layout, output_layout;
if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
return FAILED;
}
if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
return FAILED;
}
TensorInfo input_tensor_info(input_layout);
TensorInfo output_tensor_info(output_layout);
inputs_tensor_info_.push_back(input_tensor_info);
outputs_tensor_info_.push_back(output_tensor_info);
return SUCCESS;
}
// Note: if the batch dimension is not fully fetched, the batch strategy may not work.
std::shared_ptr<Strategys> StridedSliceInfo::GenerateBatchStrategies() {
split_flag_list_ = {true};

View File

@ -48,7 +48,6 @@ class StridedSliceInfo : public OperatorInfo {
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferMirrorOps() override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status GetMask(const std::string &mask_name, int64_t *mask_value);

View File

@ -144,14 +144,14 @@ Status TensorDotInfo::CheckStrategy(const StrategyPtr &strategy) {
if (axes_type_ == INT_TYPE) { // for example: axes = 3, [a, b, c, d] and [b, c, d, e]
for (int32_t i = 0; i < axes_int_; ++i) {
if (input_a_strategy[input_a_strategy.size() - axes_int_ + i] != input_b_strategy[i]) {
MS_LOG(ERROR) << name_ << ": The strategies of relavent dimensions are no equal";
MS_LOG(ERROR) << name_ << ": The strategies of relevant dimensions are not equal";
return FAILED;
}
}
} else if (axes_type_ == TUPLE_TUPLE_TYPE) {
for (size_t i = 0; i < axes_tuple_tuple_[0].size(); ++i) {
if (input_a_strategy[axes_tuple_tuple_[0][i]] != input_b_strategy[axes_tuple_tuple_[1][i]]) {
MS_LOG(ERROR) << name_ << ": The strategies of relavent dimensions are no equal";
MS_LOG(ERROR) << name_ << ": The strategies of relevant dimensions are not equal";
return FAILED;
}
}
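A worked instance of the two checks above, with hypothetical numbers: for axes_int_ = 2 with shapes [a, b, c, d] and [c, d, e], the loop compares input_a_strategy[2] with input_b_strategy[0] and input_a_strategy[3] with input_b_strategy[1], so strategies (2, 1, 4, 2) and (4, 2, 1) pass while (2, 1, 4, 2) and (2, 4, 1) fail. For axes_tuple_tuple_ = ((0, 1), (2, 3)), the required equalities are input_a_strategy[0] == input_b_strategy[2] and input_a_strategy[1] == input_b_strategy[3].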
@ -196,42 +196,10 @@ Status TensorDotInfo::InferDevMatrixShape() {
return SUCCESS;
}
Status TensorDotInfo::InferMirrorOps() {
mirror_ops_.clear();
Shape input_a_tensor_map = inputs_tensor_map_[0];
Shape input_b_tensor_map = inputs_tensor_map_[1];
std::vector<Group> input_a_group, input_b_group;
if ((CreateGroupByTensorMap(input_a_tensor_map, &input_a_group) != SUCCESS) ||
(CreateGroupByTensorMap(input_b_tensor_map, &input_b_group) != SUCCESS)) {
MS_LOG(ERROR) << name_ << ": Create group by tensor map failed";
return FAILED;
}
if (input_a_group.empty() && input_b_group.empty()) {
MS_LOG(INFO) << name_ << ": The mirror ops is empty";
return SUCCESS;
}
OperatorVector op_for_input_a, op_for_input_b;
if (!input_a_group.empty()) {
op_for_input_a = CreateMirrorOps(input_a_group[0].name(), input_a_group[0].GetDevNum());
MS_LOG(INFO) << name_ << ": Create the mirror ops for input_a success, group is " << input_a_group[0].name();
}
if (!input_b_group.empty()) {
op_for_input_b = CreateMirrorOps(input_b_group[0].name(), input_b_group[0].GetDevNum());
MS_LOG(INFO) << name_ << ": Create the mirror ops for input_b success, group is " << input_b_group[0].name();
}
mirror_ops_.push_back(op_for_input_a);
mirror_ops_.push_back(op_for_input_b);
return SUCCESS;
}
Status TensorDotInfo::InferForwardCommunication() {
forward_op_.clear();
Shape forward_group_map = outputs_tensor_map_[0];
// handel the repeat calculation, the forward communication's group can not include the dimension of repeat
// handle the repeat calculation, the forward communication's group can not include the dimension of repeat
// calculation
if (repeated_calc_num_ > 1) {
if (repeated_num_in_dev_matrix_right_) {
@ -353,37 +321,6 @@ Status TensorDotInfo::InferTensorMap() {
return SUCCESS;
}
Status TensorDotInfo::InferTensorInfo() {
if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": Invalid args";
return FAILED;
}
TensorLayout input_layout, output_layout;
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
// infer tensor layout
if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[i], inputs_shape_[i]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
return FAILED;
}
TensorInfo input_tensor_info(input_layout);
inputs_tensor_info_.push_back(input_tensor_info);
}
if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
return FAILED;
}
TensorInfo output_tensor_info(output_layout);
outputs_tensor_info_.push_back(output_tensor_info);
for (size_t i = 0; i < inputs_tensor_info_.size(); i++) {
MS_LOG(INFO) << name_ << ": The input " << i << " layout: " << inputs_tensor_info_[i].tensor_layout().ToString();
}
MS_LOG(INFO) << name_ << ": The output layout: " << outputs_tensor_info_[0].tensor_layout().ToString();
return SUCCESS;
}
Status TensorDotInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed";
@ -415,23 +352,23 @@ std::shared_ptr<Strategys> TensorDotInfo::GenerateBatchStrategies() {
if (axes_type_ == INT_TYPE) {
if (IntToSize(axes_int_) == inputs_shape_[0].size()) {
input_b_strategy[0] = stage_device_size_; // find the relavent dimension for input_b
input_b_strategy[0] = stage_device_size_; // find the relevant dimension for input_b
}
} else if (axes_type_ == TUPLE_TUPLE_TYPE) {
// if the input_a's axes contain 0, the input_b has the relavent dimension with batch dimension
// if the input_a's axes contain 0, the input_b has the relevant dimension with batch dimension
bool found = false;
size_t relavant_index = 0;
size_t relevant_index = 0;
for (size_t i = 0; i < axes_tuple_tuple_[0].size(); ++i) {
if (axes_tuple_tuple_[0][i] == 0) {
found = true;
relavant_index = i;
relevant_index = i;
break;
}
}
if (found) {
// find the relavant
input_b_strategy[axes_tuple_tuple_[1][relavant_index]] = stage_device_size_;
// find the relevant
input_b_strategy[axes_tuple_tuple_[1][relevant_index]] = stage_device_size_;
}
} else {
MS_LOG(EXCEPTION) << name_ << ": Now do not support TUPLE_TYPE";
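To pin down the relevant-index logic above with a hypothetical example: with axes_tuple_tuple_ = ((0, 2), (1, 3)), input_a's batch dimension 0 is contracted against input_b's dimension 1 (found is true at i = 0, so relevant_index = 0), and the batch strategy must therefore also set input_b_strategy[1] = stage_device_size_; if axes_tuple_tuple_[0] does not contain 0, input_a's batch dimension is free and input_b's strategy is presumably left unsplit.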

View File

@ -54,9 +54,7 @@ class TensorDotInfo : public OperatorInfo {
protected:
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferMirrorOps() override;
Status InferForwardCommunication() override;
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status GetAttrs() override;

View File

@ -155,30 +155,6 @@ Status TileInfo::InferMirrorOps() {
return SUCCESS;
}
Status TileInfo::InferTensorInfo() {
if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": Invalid args";
return FAILED;
}
// infer tensor layout
TensorLayout input_layout, output_layout;
if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
return FAILED;
}
if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
return FAILED;
}
TensorInfo input_tensor_info(input_layout);
TensorInfo output_tensor_info(output_layout);
inputs_tensor_info_.push_back(input_tensor_info);
outputs_tensor_info_.push_back(output_tensor_info);
return SUCCESS;
}
void TileInfo::UpdateMultiples(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() != 3) {

View File

@ -49,7 +49,6 @@ class TileInfo : public OperatorInfo {
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferMirrorOps() override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;

View File

@ -48,32 +48,6 @@ Status TmpIdentityInfo::InferTensorMap() {
return SUCCESS;
}
Status TmpIdentityInfo::InferTensorInfo() {
// infer tensor shape
Shape input_shape = inputs_shape_.at(0);
// infer slice shape
Shapes inputs_slice_shape, outputs_slice_shape;
Strategys inputs_strategy = strategy_->GetInputDim();
Strategys outputs_strategy = {inputs_strategy.at(0)};
if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
return FAILED;
}
Shape input_slice_shape = inputs_slice_shape.at(0);
TensorLayout input_tensor_layout;
if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) {
return FAILED;
}
TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
inputs_tensor_info_.push_back(input_tensor_info);
outputs_tensor_info_.push_back(input_tensor_info); // the same as input
return SUCCESS;
}
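The explicit InferSliceShape detour removed here is redundant because TensorLayout can derive the slice shape from the device matrix and tensor map alone. A worked calculation with hypothetical numbers: given dev_matrix_shape_ = {4, 2}, tensor map (1, 0) and shape {64, 32}, dimension 0 maps to dev matrix entry 1 (value 4, counted from the right) and dimension 1 to entry 0 (value 2), giving a slice shape of {64 / 4, 32 / 2} = {16, 16}.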
Status TmpIdentityInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";

View File

@ -48,7 +48,6 @@ class TmpIdentityInfo : public OperatorInfo {
Status GetAttrs() override { return SUCCESS; }
Status InferMirrorOps() override { return SUCCESS; }
Status InferForwardCommunication() override { return SUCCESS; }
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
};

View File

@ -82,33 +82,6 @@ Status TopKInfo::InferTensorMap() {
return SUCCESS;
}
Status TopKInfo::InferTensorInfo() {
if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": Invalid args";
return FAILED;
}
TensorLayout input_layout, output_layout;
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
// infer tensor layout
if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[i], inputs_shape_[i]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
return FAILED;
}
TensorInfo input_tensor_info(input_layout);
inputs_tensor_info_.push_back(input_tensor_info);
}
if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
return FAILED;
}
TensorInfo output_tensor_info(output_layout);
outputs_tensor_info_.push_back(output_tensor_info); // values
outputs_tensor_info_.push_back(output_tensor_info); // indices
return SUCCESS;
}
Status TopKInfo::InferAsLossDivisor() {
if (outputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": The outputs tensor map is empty";

View File

@ -47,7 +47,6 @@ class TopKInfo : public OperatorInfo {
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferMirrorOps() override;  // cannot use OperatorInfo::InferMirrorOps(), since the 'k' of topk is a scalar
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status InferAsLossDivisor() override;

View File

@ -108,63 +108,6 @@ Status TransposeInfo::InferTensorMap() {
return SUCCESS;
}
// the output tensor strategy is the permutation of input tensor strategy, the permutation is axis_v
Strategys TransposeInfo::GetOutputsStrategy() {
Strategys outputs_strategy;
Dimensions strategy = input_strategy_;
for (uint64_t i = 0; i < strategy.size(); i++) {
strategy[i] = input_strategy_[LongToUlong(axis_v_[i])];
}
outputs_strategy.push_back(strategy);
return outputs_strategy;
}
Status TransposeInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) {
if ((inputs_layout == nullptr) || (outputs_layout == nullptr)) {
MS_LOG(ERROR) << name_ << ": InferTensorLayout: the layout is null.";
return FAILED;
}
Shape shape_in = inputs_shape_.at(0);
TensorMap tensor_map_in = inputs_tensor_map_.at(0);
Shape shape_out = outputs_shape_.at(0);
TensorMap tensor_map_out = outputs_tensor_map_.at(0);
TensorLayout tensor_layout_in, tensor_layout_out;
if ((tensor_layout_in.InitFromVector(dev_matrix_shape_, tensor_map_in, shape_in) != SUCCESS) ||
(tensor_layout_out.InitFromVector(dev_matrix_shape_, tensor_map_out, shape_out) != SUCCESS)) {
return FAILED;
}
inputs_layout->push_back(tensor_layout_in);
outputs_layout->push_back(tensor_layout_out);
return SUCCESS;
}
Status TransposeInfo::InferTensorInfo() {
Shapes inputs_slice_shape, outputs_slice_shape;
Strategys inputs_strategy = strategy_->GetInputDim();
Strategys outputs_strategy = GetOutputsStrategy();
if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
return FAILED;
}
TensorLayouts inputs_layout, outputs_layout;
if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) {
return FAILED;
}
TensorLayout tensor_layout_in = inputs_layout.at(0);
TensorLayout tensor_layout_out = outputs_layout.at(0);
Shape shape_array_in = inputs_shape_.at(0);
Shape slice_shape_in = inputs_slice_shape.at(0);
Shape shape_array_out = outputs_shape_.at(0);
Shape slice_shape_out = outputs_slice_shape.at(0);
TensorInfo tensor_info_in(tensor_layout_in, shape_array_in, slice_shape_in);
TensorInfo tensor_info_out(tensor_layout_out, shape_array_out, slice_shape_out);
inputs_tensor_info_.push_back(tensor_info_in);
outputs_tensor_info_.push_back(tensor_info_out);
return SUCCESS;
}
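The deleted GetOutputsStrategy/InferTensorInfo pair only re-derived what the tensor maps already encode: the output strategy is the input strategy permuted by axis_v_. A quick worked example with hypothetical values: input_strategy_ = (2, 4, 1) and axis_v_ = (1, 2, 0) give strategy[0] = input_strategy_[1] = 4, strategy[1] = input_strategy_[2] = 1, strategy[2] = input_strategy_[0] = 2, i.e. an output strategy of (4, 1, 2).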
// compute axis_v_ during this method
Status TransposeInfo::GetAttrs() { return ComputeAxis(); }

View File

@ -46,12 +46,9 @@ class TransposeInfo : public OperatorInfo {
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferMirrorOps() override;
Status InferForwardCommunication() override;
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout);
Status GetAttrs() override;
Strategys GetOutputsStrategy();
private:
Status ComputeAxis();

View File

@ -147,61 +147,6 @@ Status UniformCandidateSamplerInfo::InferAsLossDivisor() {
return SUCCESS;
}
Status UniformCandidateSamplerInfo::InferMirrorOps() {
mirror_ops_.clear();
if (inputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": The inputs tensor map is empty";
return FAILED;
}
Shape input_tensor_map = inputs_tensor_map_[0];
std::vector<Group> group;
if (CreateGroupByTensorMap(input_tensor_map, &group) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Create group for input failed.";
return FAILED;
}
OperatorVector mirror_op;
if (group.empty()) {
MS_LOG(INFO) << name_ << ": The mirror group is empty.";
return SUCCESS;
} else {
mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
mirror_ops_.push_back(mirror_op);
std::string group_name = group[0].name();
MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name;
}
return SUCCESS;
}
Status UniformCandidateSamplerInfo::InferTensorInfo() {
if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": Invalid args";
return FAILED;
}
TensorLayout input_layout, output_layout;
// infer tensor layout
if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
return FAILED;
}
TensorInfo input_tensor_info(input_layout);
inputs_tensor_info_.push_back(input_tensor_info);
for (size_t i = 0; i < outputs_shape_.size(); ++i) {
// infer tensor layout
if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[i], outputs_shape_[i]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
return FAILED;
}
TensorInfo output_tensor_info(output_layout);
outputs_tensor_info_.push_back(output_tensor_info);
}
return SUCCESS;
}
Status UniformCandidateSamplerInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
return SetCostUnderStrategyBase(strategy);
}

View File

@ -53,9 +53,7 @@ class UniformCandidateSamplerInfo : public OperatorInfo {
protected:
Status GetAttrs() override;
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferMirrorOps() override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status ComputeReplaceGraph(const CNodePtr &cnode);

View File

@ -53,42 +53,6 @@ Status UniqueInfo::InferTensorMap() {
return SUCCESS;
}
Status UniqueInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) {
if (inputs_layout == nullptr || outputs_layout == nullptr) {
MS_LOG(ERROR) << name_ << " : The layout is null.";
return FAILED;
}
TensorLayout input_layout;
TensorLayout output_layout;
TensorLayout index_layout;
if ((input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) ||
(output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) ||
(index_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[1], outputs_shape_[1]) != SUCCESS)) {
return FAILED;
}
inputs_layout->push_back(input_layout);
outputs_layout->push_back(output_layout);
outputs_layout->push_back(index_layout);
return SUCCESS;
}
Status UniqueInfo::InferTensorInfo() {
TensorLayouts inputs_layout;
TensorLayouts outputs_layout;
if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) {
return FAILED;
}
for (size_t i = 0; i < inputs_layout.size(); ++i) {
TensorInfo input_tensor_info(inputs_layout[i]);
inputs_tensor_info_.push_back(input_tensor_info);
}
for (size_t i = 0; i < outputs_layout.size(); ++i) {
TensorInfo output_tensor_info(outputs_layout[i]);
outputs_tensor_info_.push_back(output_tensor_info);
}
return SUCCESS;
}
Status UniqueInfo::InferDevMatrixShape() {
dev_matrix_shape_.push_back(stage_device_size_);
return SUCCESS;
@ -132,29 +96,6 @@ Status UniqueInfo::GetAttrs() {
return SUCCESS;
}
Status UniqueInfo::InferMirrorOps() {
mirror_ops_.clear();
Shape tensor_map = inputs_tensor_map_[0];
std::vector<Group> group;
if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Create group failed.";
return FAILED;
}
OperatorVector mirror_op;
if (group.empty()) {
MS_LOG(INFO) << name_ << " : The mirror ops is empty.";
return SUCCESS;
} else {
mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
mirror_ops_.push_back(mirror_op);
std::string group_name = group[0].name();
MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name;
}
return SUCCESS;
}
Status UniqueInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Init for cost model failed.";

View File

@ -46,9 +46,7 @@ class UniqueInfo : public OperatorInfo {
Status GetAttrs() override;
Status InferTensorMap() override;
Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout);
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferMirrorOps() override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferAsLossDivisor() override { return SUCCESS; }
#if (ENABLE_CPU && !_WIN32)

View File

@ -167,29 +167,6 @@ Status UnsortedSegmentOpInfo::InferTensorMap() {
return SUCCESS;
}
Status UnsortedSegmentOpInfo::InferTensorInfo() {
// infer tensor shape
Shape input_shape = inputs_shape_.at(0);
Shape input_index_shape = inputs_shape_.at(1);
Shape output_shape = outputs_shape_.at(0);
TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout;
if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) ||
(input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) ||
(output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) != SUCCESS)) {
return FAILED;
}
TensorInfo input_tensor_info(input_tensor_layout);
TensorInfo input_index_info(input_index_layout);
TensorInfo output_tensor_info(output_tensor_layout);
inputs_tensor_info_.push_back(input_tensor_info);
inputs_tensor_info_.push_back(input_index_info);
outputs_tensor_info_.push_back(output_tensor_info);
return SUCCESS;
}
Status UnsortedSegmentOpInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";

View File

@ -48,7 +48,6 @@ class UnsortedSegmentOpInfo : public OperatorInfo {
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferForwardCommunication() override;
Status InferMirrorOps() override;
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status GetAttrs() override;

View File

@ -102,20 +102,6 @@ Status VirtualDatasetInfo::InferTensorMap() {
return SUCCESS;
}
Status VirtualDatasetInfo::InferTensorInfo() {
for (size_t i = 0; i < strategy_->GetInputNumber(); i++) {
MS_LOG(INFO) << name_ << ": InferTensorInfo " << i << ", size " << strategy_->GetInputNumber();
TensorLayout tensor_layout_in;
if (tensor_layout_in.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(i), inputs_shape_.at(i)) != SUCCESS) {
return FAILED;
}
TensorInfo tensor_info_in(tensor_layout_in);
inputs_tensor_info_.push_back(tensor_info_in);
outputs_tensor_info_.push_back(tensor_info_in);
}
return SUCCESS;
}
Status VirtualDatasetInfo::GetAttrs() { return SUCCESS; }
Status VirtualDatasetInfo::Init(const StrategyPtr &strategy) {

View File

@ -45,7 +45,6 @@ class VirtualDatasetInfo : public OperatorInfo {
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferMirrorOps() override;
Status InferForwardCommunication() override;
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status GetAttrs() override;

View File

@ -1,122 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <list>
#include <vector>
#include "common/common_test.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/ops_info/get_next_info.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/step_parallel.h"
namespace mindspore {
namespace parallel {
class GetNextInfo;
using GetNextInfoPtr = std::shared_ptr<GetNextInfo>;
GetNextInfoPtr get_next;
class TestGetNextInfo : public UT::Common {
public:
TestGetNextInfo() {}
void SetUp();
void TearDown() {}
};
void TestGetNextInfo::SetUp() {
RankList dev_list;
for (int32_t i = 0; i < 8; i++) {
dev_list.push_back(i);
}
RankList stage_map;
stage_map.push_back(8);
int32_t local_dev = 0;
// create a new g_device_manager
g_device_manager = std::make_shared<DeviceManager>();
g_device_manager->Init(dev_list, local_dev, stage_map, "hccl");
Shapes inputs_shape = {};
Shapes outputs_shape = {{64, 32}, {64}};
std::unordered_map<std::string, ValuePtr> attr;
std::vector<std::string> types_ = {"float32", "int32"};
Shapes shapes_ = {{64, 32}, {64}};
int64_t output_num_ = 2;
std::string shared_name_ = "test_get_next";
attr["types"] = MakeValue(types_);
attr["shapes"] = MakeValue(shapes_);
attr["output_num"] = MakeValue(output_num_);
attr["shared_name"] = MakeValue(shared_name_);
get_next = std::make_shared<GetNextInfo>("get_next_info", inputs_shape, outputs_shape, attr);
}
TEST_F(TestGetNextInfo, InferDevMatrixShape1) {
Strategys inputs = {{}, {}};
StrategyPtr strategy = NewStrategy(0, inputs);
get_next->Init(strategy);
Shape dev_matrix_shape = get_next->dev_matrix_shape();
Shape expect = {8, 1};
ASSERT_EQ(dev_matrix_shape, expect);
}
TEST_F(TestGetNextInfo, InferSliceShape1) {
Strategys str = {{}, {}};
StrategyPtr strategy = NewStrategy(0, str);
get_next->Init(strategy);
std::vector<TensorInfo> outputs = get_next->outputs_tensor_info();
Shape output_slice_shape_expect0 = {8, 32};
Shape output_slice_shape_expect1 = {8};
TensorInfo output_tensor_info0 = outputs.at(0);
TensorInfo output_tensor_info1 = outputs.at(1);
Shape output_slice_shape0 = output_tensor_info0.slice_shape();
Shape output_slice_shape1 = output_tensor_info1.slice_shape();
ASSERT_EQ(output_slice_shape0, output_slice_shape_expect0);
ASSERT_EQ(output_slice_shape1, output_slice_shape_expect1);
}
TEST_F(TestGetNextInfo, GetTensorLayout1) {
Strategys str = {{}, {}};
StrategyPtr strategy = NewStrategy(0, str);
get_next->Init(strategy);
std::vector<TensorInfo> outputs = get_next->outputs_tensor_info();
TensorMap output_expect0 = {1, 0};
TensorMap output_expect1 = {1};
TensorInfo output_tensor_info0 = outputs.at(0);
TensorInfo output_tensor_info1 = outputs.at(1);
Map output_tensor_map0 = output_tensor_info0.tensor_layout().origin_tensor_map();
Map output_tensor_map1 = output_tensor_info1.tensor_layout().origin_tensor_map();
ASSERT_EQ(output_tensor_map0.array(), output_expect0);
ASSERT_EQ(output_tensor_map1.array(), output_expect1);
}
TEST_F(TestGetNextInfo, CheckStrategy1) {
Strategys inputs = {};
StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = get_next->Init(strategy);
ASSERT_EQ(ret, SUCCESS);
}
TEST_F(TestGetNextInfo, CheckStrategy2) {
Strategys inputs = {{8, 1}, {8}};
StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = get_next->Init(strategy);
ASSERT_EQ(ret, FAILED);
}
} // namespace parallel
} // namespace mindspore
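The expectations the deleted test encoded follow directly from the layout rules: with 8 devices in one stage and empty strategies, GetNext derives dev matrix {8, 1}; the output tensor maps {1, 0} and {1} bind each output's batch dimension to the dev matrix entry of value 8, so the slice shapes are {64 / 8, 32} = {8, 32} and {64 / 8} = {8}, exactly the values asserted in InferSliceShape1 and GetTensorLayout1.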

View File

@ -43,9 +43,9 @@ class Net(Cell):
return out
_x = Tensor(np.ones([128, 64, 32, 16]), dtype=ms.float32)
_w = Tensor(np.ones([128, 64, 32, 16]), dtype=ms.float32)
_b = Tensor(np.ones([128, 64, 32, 16]), dtype=ms.float32)
_x = Tensor(np.ones([16, 64, 32, 16]), dtype=ms.float32)
_w = Tensor(np.ones([16, 64, 32, 16]), dtype=ms.float32)
_b = Tensor(np.ones([16, 64, 32, 16]), dtype=ms.float32)
def compile_net(net):