From 6f914b8b3c4fba8d571c80e2e7b298f777b0c902 Mon Sep 17 00:00:00 2001 From: liuluobin Date: Thu, 17 Mar 2022 20:28:07 +0800 Subject: [PATCH] Implementation of SquaredDifferenceInfo, ErfinvInfo, MaskedFillInfo, SplitVInfo, GammaInfo, KLDivLossInfo and LinSpaceInfo --- .../auto_parallel/operator_costmodel.h | 6 + .../ccsrc/frontend/parallel/dynamic_creator.h | 7 + .../parallel/graph_util/generate_graph.cc | 13 ++ .../parallel/graph_util/generate_graph.h | 1 + .../frontend/parallel/graph_util/node_info.cc | 16 +- .../parallel/ops_info/activation_info.cc | 29 +++ .../parallel/ops_info/activation_info.h | 13 +- .../parallel/ops_info/arithmetic_info.cc | 34 +++ .../parallel/ops_info/arithmetic_info.h | 25 ++ .../frontend/parallel/ops_info/gamma_info.cc | 218 ++++++++++++++++++ .../frontend/parallel/ops_info/gamma_info.h | 64 +++++ .../parallel/ops_info/kldiv_loss_info.cc | 136 +++++++++++ .../parallel/ops_info/kldiv_loss_info.h | 57 +++++ .../parallel/ops_info/lin_space_info.cc | 176 ++++++++++++++ .../parallel/ops_info/lin_space_info.h | 62 +++++ .../parallel/ops_info/operator_info.cc | 24 ++ .../parallel/ops_info/operator_info.h | 6 +- .../parallel/ops_info/ops_info_head_files.h | 3 + .../frontend/parallel/ops_info/ops_utils.h | 9 + .../parallel/ops_info/reduce_method_info.cc | 26 +-- .../frontend/parallel/ops_info/split_info.cc | 22 +- .../frontend/parallel/ops_info/split_info.h | 13 +- .../parallel/ops_info/uniform_real_info.cc | 54 +++-- .../parallel/ops_info/uniform_real_info.h | 12 +- .../parallel/ops_info/virtual_dataset_info.cc | 16 +- .../frontend/parallel/step_auto_parallel.cc | 2 +- .../ccsrc/frontend/parallel/step_parallel.cc | 9 +- tests/ut/python/parallel/test_arithmetic.py | 86 +++++++ .../parallel/test_element_wise_function.py | 31 +++ tests/ut/python/parallel/test_gamma.py | 94 ++++++++ tests/ut/python/parallel/test_kldiv_loss.py | 127 ++++++++++ tests/ut/python/parallel/test_l2_loss.py | 21 +- tests/ut/python/parallel/test_lin_space.py | 111 +++++++++ 
tests/ut/python/parallel/test_splitv.py | 118 ++++++++++ tests/ut/python/parallel/test_uniform_real.py | 74 ++++++ 35 files changed, 1629 insertions(+), 86 deletions(-) create mode 100644 mindspore/ccsrc/frontend/parallel/ops_info/gamma_info.cc create mode 100644 mindspore/ccsrc/frontend/parallel/ops_info/gamma_info.h create mode 100644 mindspore/ccsrc/frontend/parallel/ops_info/kldiv_loss_info.cc create mode 100644 mindspore/ccsrc/frontend/parallel/ops_info/kldiv_loss_info.h create mode 100644 mindspore/ccsrc/frontend/parallel/ops_info/lin_space_info.cc create mode 100644 mindspore/ccsrc/frontend/parallel/ops_info/lin_space_info.h create mode 100644 tests/ut/python/parallel/test_gamma.py create mode 100644 tests/ut/python/parallel/test_kldiv_loss.py create mode 100644 tests/ut/python/parallel/test_lin_space.py create mode 100644 tests/ut/python/parallel/test_splitv.py create mode 100644 tests/ut/python/parallel/test_uniform_real.py diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h index de520879ef9..8d0d5bec416 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h @@ -192,6 +192,8 @@ using IOUCost = CastCost; using RandomChoicWithMaskCost = CastCost; using IsFiniteCost = CastCost; using RintCost = CastCost; +using GammaCost = CastCost; +using LinSpaceCost = CastCost; class SqrtCost : public CastCost { public: @@ -247,6 +249,7 @@ using ErfcCost = ReLU6Cost; using ActivationInfoCost = ReLU6Cost; using SelectCost = ReLU6Cost; using XlogyCost = ReLU6Cost; +using ErfinvCost = ReLU6Cost; class TransposeCost : public CastCost { public: @@ -634,6 +637,7 @@ using BitwiseOrCost = SubCost; using BitwiseXorCost = SubCost; using AddNCost = SubCost; using InplaceAddCost = SubCost; +using MaskedFillCost = SubCost; class MulCost : public SubCost { public: @@ -646,6 +650,7 @@ 
class MulCost : public SubCost { using MulNoNanCost = MulCost; using GatherDCost = MulCost; using LerpCost = MulCost; +using SquaredDifferenceCost = MulCost; class DivCost : public SubCost { public: @@ -777,6 +782,7 @@ using ReduceMethodCost = ReduceSumCost; using ReduceProdCost = ReduceSumCost; using SquareSumAllCost = ReduceSumCost; using L2LossCost = ReduceSumCost; +using KLDivLossCost = ReduceSumCost; class ReduceMeanCost : public ReduceSumCost { public: diff --git a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h index 0555fd45d78..d12a63e6a5d 100644 --- a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h +++ b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h @@ -243,6 +243,13 @@ REGISTER(InplaceSubInfo); REGISTER(CdistInfo); REGISTER(L2LossInfo); REGISTER(LerpInfo); +REGISTER(SquaredDifferenceInfo); +REGISTER(ErfinvInfo); +REGISTER(MaskedFillInfo); +REGISTER(SplitVInfo); +REGISTER(GammaInfo); +REGISTER(KLDivLossInfo); +REGISTER(LinSpaceInfo); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc index c347d71b889..69c2c8ff17e 100644 --- a/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc +++ b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc @@ -91,6 +91,19 @@ AnfNodePtr CreateInt32Tensor(int64_t value) { return anf_node_ptr; } +static mindspore::HashMap float_tensor_map = {}; +AnfNodePtr CreateFP32Tensor(float value) { + auto it = float_tensor_map.find(value); + if (it != float_tensor_map.end()) { + return it->second; + } + mindspore::tensor::TensorPtr tensor_ptr = std::make_shared(value, kFloat32); + ValuePtr value_ptr = MakeValue(tensor_ptr); + auto anf_node_ptr = ValuePtrToAnfNodePtr(value_ptr); + float_tensor_map[value] = anf_node_ptr; + return anf_node_ptr; +} + AnfNodePtr CreateTypeInt(int64_t nbits) { ValuePtr value_ptr = 
MakeValue(std::make_shared(nbits)); return ValuePtrToAnfNodePtr(value_ptr); diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h index 2dd8b3a80a2..cc1f59a7ea0 100644 --- a/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h +++ b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h @@ -42,6 +42,7 @@ AnfNodePtr CreateTypeFloat(int64_t nbits); AnfNodePtr CreatInt64Imm(int64_t value); AnfNodePtr CreateFP32Imm(float value); AnfNodePtr CreateInt32Tensor(int64_t value); +AnfNodePtr CreateFP32Tensor(float value); AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr); AnfNodePtr CreateTuple(const std::vector &tuple); std::string HashInstanceName(const std::string &name); diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc b/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc index 687d14cb2ed..cd6d0377be2 100644 --- a/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc +++ b/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc @@ -76,7 +76,16 @@ std::vector ExtractInputParameterByNode(const CNodePtr &node) { } else { inputs_seq = node_inputs[1]->cast()->value()->cast()->value(); } - return std::vector(inputs_seq.size(), false); + size_t inputs_seq_tensor_size = inputs_seq.size(); + for (const auto &inputs_seq_value : inputs_seq) { + auto tensor = inputs_seq_value->cast(); + if (tensor == nullptr) { + MS_LOG(DEBUG) << "The value is not a tensor."; + inputs_seq_tensor_size = 0; + break; + } + } + return std::vector(inputs_seq_tensor_size, false); } if ((node_inputs.size() == 2) && (AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) { @@ -168,7 +177,10 @@ std::vector ExtractInputTypeLengthByNode(const CNodePtr &node) { } for (auto &ele : inputs_seq) { auto tensor = ele->cast(); - MS_EXCEPTION_IF_NULL(tensor); + if (tensor == nullptr) { + inputs_type_len.clear(); + return 
inputs_type_len; + } inputs_type_len.push_back(GetLengthOfDataType(tensor->Dtype())); } return inputs_type_len; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc index 2c5ac0a0bf8..96e4914c3ee 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc @@ -28,6 +28,7 @@ #include "frontend/parallel/device_matrix.h" #include "frontend/parallel/strategy.h" #include "frontend/parallel/tensor_layout/redistribution_operator_infer.h" +#include "frontend/parallel/graph_util/generate_graph.h" namespace mindspore { namespace parallel { @@ -616,5 +617,33 @@ Status L2LossInfo::InferTensorMap() { outputs_tensor_map_[0].clear(); return SUCCESS; } + +Status L2LossInfo::InferForwardCommunication() { + forward_op_.clear(); + Shape group_create_map; + if (repeated_calc_num_ > 1) { + if (repeated_num_in_dev_matrix_right_) { + group_create_map.push_back(0); + } else { + group_create_map.push_back(SizeToLong(dev_matrix_shape_.size()) - 1); + } + } + + std::vector group_list; + if (CreateGroupByTensorMap(group_create_map, &group_list) != SUCCESS) { + ReportError(name_ + ": Create group failed."); + return FAILED; + } + if (group_list.empty()) { + MS_LOG(INFO) << name_ << ": Forward all reduce is not required"; + return SUCCESS; + } + + Operator op = CreateAllReduceOp(REDUCE_OP_SUM, group_list[0].name()); + forward_op_.push_back(op); + MS_LOG(INFO) << name_ << ": The group name of forward all reduce is " << group_list[0].name(); + + return SUCCESS; +} } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.h index 18abf29e6c5..3659d475c72 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.h @@ -374,13 
+374,22 @@ class SoftShrinkInfo : public ActivationOther { class L2LossInfo : public ActivationOther { public: - L2LossInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &output_shape, + L2LossInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, output_shape, attrs, std::make_shared()) {} + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~L2LossInfo() = default; protected: Status InferTensorMap() override; + Status InferForwardCommunication() override; +}; + +class ErfinvInfo : public ActivationOther { + public: + ErfinvInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + ~ErfinvInfo() override = default; }; } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc index 6c217699ca5..5018a082ad3 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc @@ -353,5 +353,39 @@ void LerpInfo::ReComputeBatchSplitFlagList() { Shape expand_weight_shape = ExpandShape(expand_a_shape, inputs_shape_.at(2)); (expand_weight_shape.at(0) != 1) ? 
(split_flag_list_[2] = true) : (split_flag_list_[2] = false); } + +Status MaskedFillInfo::GetAttrs() { + input_size_ = inputs_shape_.size(); + if (input_size_ != 2 && input_size_ != 3) { + MS_LOG(ERROR) << name_ << ": inputs_shape_.size() must be 2 or 3, but got size " << input_size_; + return FAILED; + } + return SUCCESS; +} + +Status MaskedFillInfo::InferTensorMap() { + if (ArithmeticBase::InferTensorMap() != SUCCESS) { + return FAILED; + } + + if (input_size_ == 3) { + // append a void tensor map for 0-dimensional tensor input 'value' + (void)inputs_tensor_map_.emplace_back(TensorMap()); + } + return SUCCESS; +} + +std::vector MaskedFillInfo::GenerateOpStrategies(int64_t stage_id) { + auto sp_vector = ArithmeticBase::GenerateOpStrategies(stage_id); + if (input_size_ == 3) { + // append void strategy for input `value` + for (size_t i = 0; i < sp_vector.size(); ++i) { + auto strategies = sp_vector[i]->GetInputDim(); + (void)strategies.emplace_back(Dimensions()); + sp_vector[i] = std::make_shared(stage_id, strategies); + } + } + return sp_vector; +} } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.h index 12c55847f10..b1f392db3dd 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.h @@ -263,6 +263,31 @@ class LerpInfo : public ArithmeticBase { private: size_t inputs_size_ = 0; }; + +class SquaredDifferenceInfo : public ArithmeticBase { + public: + SquaredDifferenceInfo(const std::string &name, const Shapes &input_shape, const Shapes &output_shape, + const PrimitiveAttrs &attrs) + : ArithmeticBase(name, input_shape, output_shape, attrs, std::make_shared()) {} + ~SquaredDifferenceInfo() override = default; +}; + +class MaskedFillInfo : public ArithmeticBase { + public: + MaskedFillInfo(const std::string &name, const Shapes &inputs_shape, const 
Shapes &output_shape, + const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, output_shape, attrs, std::make_shared()) {} + ~MaskedFillInfo() override = default; + + std::vector GenerateOpStrategies(int64_t stage_id) override; + + protected: + Status GetAttrs() override; + Status InferTensorMap() override; + + private: + size_t input_size_ = 0; +}; } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gamma_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gamma_info.cc new file mode 100644 index 00000000000..d4c71c00e68 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gamma_info.cc @@ -0,0 +1,218 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +#include "frontend/parallel/ops_info/gamma_info.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" +#include "frontend/parallel/auto_parallel/edge_costmodel.h" + +namespace mindspore { +namespace parallel { +Status GammaInfo::InferAttrs() { + if (infer_attrs_completed_) { + return SUCCESS; + } + + if (GetAttrs() != SUCCESS) { + return FAILED; + } + ResetInputsShape(); + infer_attrs_completed_ = true; + return SUCCESS; +} + +Status GammaInfo::GetAttrs() { + seed_ = GetIntAttr(SEED); + if (seed_ < 0) { + MS_LOG(ERROR) << name_ << ": Seed must be greater or equal to zero, but got " << seed_; + return FAILED; + } + seed2_ = GetIntAttr(SEED2); + if (seed2_ < 0) { + MS_LOG(ERROR) << name_ << ": Seed2 must be greater or equal to zero, but got " << seed2_; + return FAILED; + } + return SUCCESS; +} + +Status GammaInfo::CheckStrategy(const StrategyPtr &strategy) { + MS_EXCEPTION_IF_NULL(strategy); + if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Invalid strategy"; + return FAILED; + } + auto strategies = strategy->GetInputDim(); + auto alpha_strategy = strategies.at(1); + auto beta_strategy = strategies.at(2); + if (!IsNotSplittableStrategy(alpha_strategy) || !IsNotSplittableStrategy(beta_strategy)) { + MS_LOG(ERROR) << name_ << ": Cannot shard the input `alpha` and `beta`, but got strategy " + << StrategyToString(strategies); + return FAILED; + } + return SUCCESS; +} + +Status GammaInfo::InferDevMatrixShape() { + dev_matrix_shape_.clear(); + auto strategies = strategy_->GetInputDim(); + auto shape_strategy = strategies.at(0); + dev_matrix_shape_ = shape_strategy; + + MS_LOG(INFO) << name_ << ": The dev matrix is " << ShapeToString(dev_matrix_shape_); + return SUCCESS; +} + +Status GammaInfo::InferTensorMap() { + auto strategies = strategy_->GetInputDim(); + auto shape_strategy = strategies.at(0); + size_t size = shape_strategy.size(); + TensorMap shape_tensor_map; + for (size_t i
= 0; i < size; ++i) { + shape_tensor_map.push_back(size - i - 1); + } + + inputs_tensor_map_.push_back(shape_tensor_map); + (void)inputs_tensor_map_.emplace_back(TensorMap(strategies.at(1).size(), -1)); + (void)inputs_tensor_map_.emplace_back(TensorMap(strategies.at(2).size(), -1)); + (void)outputs_tensor_map_.emplace_back(shape_tensor_map); + return SUCCESS; +} + +std::vector GammaInfo::GenerateOpStrategies(int64_t stage_id) { + Shape input0_split(inputs_shape_[0].size(), 1); + Shape input1_split(inputs_shape_[1].size(), 0); + Shape input2_split(inputs_shape_[2].size(), 0); + Shapes splittable_inputs = {input0_split, input1_split, input2_split}; + + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(EXCEPTION) << name_ << ": Generate strategies for independent inputs() failed."; + } + if (sp_vector.empty()) { + MS_LOG(EXCEPTION) << name_ << ": No available strategy."; + } + return sp_vector; +} + +void GammaInfo::UpdateShape(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + auto input_node = cnode->input(1)->cast(); + std::vector input_shape = GetValue>(input_node->value()); + std::vector stra = strategy_->GetInputDim(); + for (size_t i = 0; i < stra[0].size(); i++) { + input_shape[i] /= stra[0][i]; + } + auto func_graph = cnode->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + ValuePtr new_shape = MakeValue(input_shape); + AnfNodePtr val = NewValueNode(new_shape); + (void)manager->Replace(cnode->input(1), val); +} + +void GammaInfo::ReplaceNodeInputOrAttrs() { + // Replace input 'shape' to slice shape + auto cnode = cnode_; + UpdateShape(cnode); + + // Update seed according rank_id + int64_t rank_id = g_device_manager->rank_index_in_stage(); + int64_t seed_bias; + if (repeated_num_in_dev_matrix_right_) { + seed_bias = rank_id / repeated_calc_num_; + } else { + int64_t device_num = 
stage_device_size_; + seed_bias = rank_id % (device_num / repeated_calc_num_); + } + + auto prim = GetValueNode(cnode->input(0)); + prim->set_attr(SEED, MakeValue(seed_ + seed_bias)); + prim->set_attr(SEED2, MakeValue(seed2_ + seed_bias)); +} + +void GammaInfo::ResetInputsShape() { + if (inputs_shape_.size() == input_value_.size()) { + return; + } + + ValueTuplePtr shape_value = input_value_[0]->cast(); + MS_EXCEPTION_IF_NULL(shape_value); + (void)inputs_shape_.insert(inputs_shape_.begin(), GetValue(shape_value)); +} + +bool GammaInfo::IsNotSplittableStrategy(const Dimensions &strategy) { + return std::all_of(strategy.begin(), strategy.end(), [](int64_t val) { return val == 1; }); +} + +void GammaInfo::ReComputeBatchSplitFlagList() { + ResetInputsShape(); + split_flag_list_.clear(); + split_flag_list_ = std::vector(inputs_shape_.size(), false); + if (!split_flag_list_.empty()) { + split_flag_list_[0] = true; + } +} + +int64_t GammaInfo::ComputeOpAndPrevEdgeParameterInvolved() { + if (is_output_parameter_involve_ != -1) { + return is_output_parameter_involve_; + } + is_parameter_involve_ = is_parameter_; + const auto &prev_edges = this->GetAlivePrevEdges(); + for (auto &p_edge : prev_edges) { + auto input_index = p_edge->next_op_input_index(); + auto prev_op_para = p_edge->prev_operator()->ComputeOpAndPrevEdgeParameterInvolved(); + if (input_index - 1 >= is_parameter_involve_.size()) { + MS_LOG(EXCEPTION) << name_ << " has input length: " << is_parameter_involve_.size() + << ", but got wrong input_index: " << input_index; + } + if (prev_op_para == 0) { + is_parameter_involve_[input_index - 1] = false; + } else if (prev_op_para == 1) { + is_parameter_involve_[input_index - 1] = true; + } else { + MS_LOG(EXCEPTION) << name_ << " got wrong value: " << prev_op_para << ", input_index: " << input_index; + } + p_edge->set_parameter_involve(prev_op_para); + } + if (std::any_of(is_parameter_involve_.begin(), is_parameter_involve_.end(), [](bool value) { return value; })) { + 
// If anyone of the input is a parameter_involved, the output is parameter_involved. + is_output_parameter_involve_ = 1; + } else { + is_output_parameter_involve_ = 0; + } + // Set 'is_parameter_involve_' and 'is_output_parameter_involve_' into operatorCost, which are used in + // calculating 'inputs_in_memory' and 'output_in_memory', respectively. + operator_cost()->set_is_parameter_involve(is_parameter_involve_); + operator_cost()->set_output_parameter_involve(is_output_parameter_involve_); + // Calculating 'output_in_memory' + operator_cost()->CalculateOutputInMemory(); + // Calculating 'inputs_in_memory' + std::map input_in_memory; + for (auto &p_edge : prev_edges) { + auto input_index = p_edge->next_op_input_index(); + auto is_in_mem = p_edge->prev_operator()->operator_cost()->is_output_in_memory(); + (void)input_in_memory.emplace(std::make_pair(input_index, is_in_mem)); + } + operator_cost()->CalculateInputsInMemory(input_in_memory); + + return is_output_parameter_involve_; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gamma_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/gamma_info.h new file mode 100644 index 00000000000..c4cc51cae34 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gamma_info.h @@ -0,0 +1,64 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GAMMA_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GAMMA_INFO_H_ + +#include +#include +#include + +#include "utils/hash_map.h" +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class GammaInfo : public OperatorInfo { + public: + GammaInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + ~GammaInfo() override = default; + + std::vector GenerateOpStrategies(int64_t) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); } + void UpdateShape(const CNodePtr &cnode); + void ReplaceNodeInputOrAttrs() override; + void ReComputeBatchSplitFlagList() override; + int64_t ComputeOpAndPrevEdgeParameterInvolved() override; + + protected: + Status InferAttrs() override; + Status GetAttrs() override; + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferForwardCommunication() override { return SUCCESS; } + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + + private: + void ResetInputsShape(); + bool IsNotSplittableStrategy(const Dimensions &strategy); + + int64_t seed_ = 0; + int64_t seed2_ = 0; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GAMMA_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/kldiv_loss_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/kldiv_loss_info.cc new file mode 100644 index 00000000000..7431ae97dec --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/kldiv_loss_info.cc @@ -0,0 +1,136 @@ +/** + * Copyright 2022 Huawei Technologies Co., 
Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "frontend/parallel/ops_info/kldiv_loss_info.h" +#include "frontend/parallel/graph_util/generate_graph.h" + +namespace mindspore { +namespace parallel { +Status KLDivLossInfo::GetAttrs() { + reduction_ = GetStringAttr(REDUCTION); + return SUCCESS; +} + +Status KLDivLossInfo::CheckStrategy(const StrategyPtr &strategy) { + if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Invalid strategy."; + return FAILED; + } + + auto strategies = strategy->GetInputDim(); + if (strategies[0] != strategies[1]) { + MS_LOG(ERROR) << name_ << ": The strategy of 'logits' and 'labels' must be the same, but got strategy " + << StrategyToString(strategies); + return FAILED; + } + return SUCCESS; +} + +Status KLDivLossInfo::InferDevMatrixShape() { + dev_matrix_shape_.clear(); + auto strategies = strategy_->GetInputDim(); + auto logits_strategy = strategies.at(0); + dev_matrix_shape_ = logits_strategy; + + MS_LOG(INFO) << name_ << ": dev matrix is " << ShapeToString(dev_matrix_shape_); + return SUCCESS; +} + +Status KLDivLossInfo::InferTensorMap() { + inputs_tensor_map_.clear(); + outputs_tensor_map_.clear(); + + TensorMap sub_tensor_map; + auto strategies = strategy_->GetInputDim(); + auto logits_strategy = strategies.at(0); + size_t size = logits_strategy.size(); + for (size_t i = 0; i < size; ++i) { + sub_tensor_map.push_back(size - i - 1); + } + + 
inputs_tensor_map_.push_back(sub_tensor_map); + inputs_tensor_map_.push_back(sub_tensor_map); + + if (reduction_ == ATTR_NONE) { + (void)outputs_tensor_map_.emplace_back(sub_tensor_map); + } else { + (void)outputs_tensor_map_.emplace_back(TensorMap()); + } + return SUCCESS; +} + +std::vector KLDivLossInfo::GenerateOpStrategies(int64_t stage_id) { + Shape splittable_input(inputs_shape_[0].size()); + std::iota(splittable_input.begin(), splittable_input.end(), 1); + Shapes splittable_inputs(inputs_shape_.size(), splittable_input); + std::vector sp_vector; + if (GenerateStrategiesForDependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(EXCEPTION) << name_ << ": Generate strategies for dependent inputs() failed."; + } + if (sp_vector.empty()) { + MS_LOG(EXCEPTION) << name_ << ": No available strategy."; + } + return sp_vector; +} + +Status KLDivLossInfo::InferGroup() { + Shape group_create_map; + if (repeated_calc_num_ > 1) { + if (repeated_num_in_dev_matrix_right_) { + group_create_map.push_back(0); + } else { + group_create_map.push_back(SizeToLong(dev_matrix_shape_.size()) - 1); + } + } + + if (CreateGroupByTensorMap(group_create_map, &group_list_) != SUCCESS) { + ReportError(name_ + ": Create group failed."); + return FAILED; + } + return SUCCESS; +} + +Status KLDivLossInfo::InferForwardCommunication() { + if (reduction_ == ATTR_NONE) { + MS_LOG(DEBUG) << name_ << ": reduction is " << reduction_ << ", there is no need to append reduce op."; + return SUCCESS; + } + if (InferGroup() != SUCCESS) { + MS_LOG(ERROR) << "Infer group failed"; + return FAILED; + } + if (group_list_.empty()) { + MS_LOG(INFO) << name_ << ": Forward all reduce is not required"; + return SUCCESS; + } + + forward_op_.clear(); + if (reduction_ == MEAN) { + auto element_type = outputs_dtype_->cast()->element(); + forward_op_ = CreateReduceMeanForwardOp(group_list_, element_type); + } else { + (void)forward_op_.emplace_back(CreateAllReduceOp(REDUCE_OP_SUM, 
group_list_[0].name())); + } + MS_LOG(INFO) << name_ << ": The group name of forward all reduce is " << group_list_[0].name(); + + return SUCCESS; +} + +void KLDivLossInfo::ReComputeBatchSplitFlagList() { split_flag_list_.assign(inputs_shape_.size(), true); } +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/kldiv_loss_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/kldiv_loss_info.h new file mode 100644 index 00000000000..560bd96882f --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/kldiv_loss_info.h @@ -0,0 +1,57 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_KLDIV_LOSS_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_KLDIV_LOSS_INFO_H_ + +#include +#include +#include + +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" + +namespace mindspore { +namespace parallel { +class KLDivLossInfo : public OperatorInfo { + public: + KLDivLossInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + ~KLDivLossInfo() override = default; + + std::vector GenerateOpStrategies(int64_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); } + void ReComputeBatchSplitFlagList() override; + + protected: + Status GetAttrs() override; + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + Status InferForwardCommunication() override; + + private: + Status InferGroup(); + + std::string reduction_; + std::vector group_list_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_KLDIV_LOSS_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/lin_space_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/lin_space_info.cc new file mode 100644 index 00000000000..abdcb64931d --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/lin_space_info.cc @@ -0,0 +1,176 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "frontend/parallel/ops_info/lin_space_info.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" +#include "frontend/parallel/graph_util/generate_graph.h" + +namespace mindspore { +namespace parallel { +Status LinSpaceInfo::GetAttrs() { + auto output_0_shape = outputs_shape_.at(0); + output_size_ = output_0_shape.at(0); + return SUCCESS; +} + +Status LinSpaceInfo::CheckStrategy(const StrategyPtr &strategy) { + auto strategies = strategy->GetInputDim(); + if (strategies.size() != 1 && strategies[0].size() != 1) { + MS_LOG(ERROR) << name_ << ": The shape of input_strategy must be [1, 1], but got strategy " + << StrategyToString(strategies); + return FAILED; + } + + auto split_num = strategies[0][0]; + if (output_size_ % split_num != 0) { + MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(strategies) << ", output size is " + << output_size_ << " cannot be divisible by strategy value " << split_num; + } + + if (stage_device_size_ % split_num != 0) { + MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(strategies) + << ", the device size in this stage is " << stage_device_size_ + << " cannot be divisible by the strategy value " << split_num; + return FAILED; + } + return SUCCESS; +} + +Status LinSpaceInfo::InferDevMatrixShape() { + dev_matrix_shape_.clear(); + auto strategies = strategy_->GetInputDim(); + auto split_strategy = strategies.at(0); + dev_matrix_shape_.push_back(split_strategy.at(0)); + MS_LOG(INFO) << name_ << ": The device matrix is " << 
ShapeToString(dev_matrix_shape_); + return SUCCESS; +} + +Status LinSpaceInfo::InferTensorMap() { + inputs_tensor_map_.clear(); + outputs_tensor_map_.clear(); + + inputs_tensor_map_.assign(2, TensorMap{}); + (void)outputs_tensor_map_.emplace_back(TensorMap{0}); + return SUCCESS; +} + +std::vector LinSpaceInfo::GenerateOpStrategies(int64_t stage_id) { + Shape split_input(1, 1); + Shape split_shape(1, output_size_); + Shapes splittable_inputs = {split_input}; + Shapes splittable_shapes = {split_shape}; + + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, splittable_shapes, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(EXCEPTION) << name_ << ": Generate strategies for independent inputs() failed."; + } + if (sp_vector.empty()) { + MS_LOG(EXCEPTION) << name_ << ": No available strategy."; + } + return sp_vector; +} + +std::shared_ptr LinSpaceInfo::GenerateBatchStrategies() { + if (InferAttrs() != SUCCESS) { + MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed"; + } + + int64_t dev_num = g_device_manager->stage_device_num(); + Strategys strategies = {Dimensions{dev_num}}; + return std::make_shared(strategies); +} + +int64_t LinSpaceInfo::GetSplitNum() { + auto strategies = strategy_->GetInputDim(); + auto split_strategy = strategies.at(0); + auto split_num = split_strategy.at(0); + return split_num; +} + +ReplaceGraphPtr LinSpaceInfo::replace_graph(const CNodePtr &cnode) { + auto split_num = GetSplitNum(); + if (split_num > 1 && ComputeReplaceGraph(cnode) != SUCCESS) { + MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed."; + } + return replace_graph_; +} + +Status LinSpaceInfo::ComputeReplaceGraph(const CNodePtr &cnode) { + auto split_num = GetSplitNum(); + if (output_size_ % split_num != 0) { + return FAILED; + } + if (split_num == 1) { + MS_LOG(INFO) << name_ << ": split num is 1, no need to replace graph"; + return SUCCESS; + } + + GenerateGraph gen_g = GenerateGraph(attrs_); + if (gen_g.Init(cnode) != SUCCESS) { + 
MS_LOG(ERROR) << "GenerateGraph Init failed."; + return FAILED; + } + if (InferSliceId() != SUCCESS) { + MS_LOG(ERROR) << "Infer slice id failed."; + return FAILED; + } + + int64_t slice_output_size = output_size_ / split_num; + auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), gen_g.virtual_input_node()}); + auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), sub}); + AnfNodePtr interval = nullptr; + if (output_size_ == 2) { + interval = sub; + } else { + auto interval_divisor = gen_g.PushBack({gen_g.NewOpInst(CAST), CreateInt32Tensor(output_size_ - 2), dtype}); + interval = gen_g.PushBack({gen_g.NewOpInst(DIV), sub, interval_divisor}); + } + + // new_start = start + slice_id * slice_output_size * interval + // new_end = new_start + (slice_output_size - 1) * interval + // new_x = slice_output_size + auto offset_size = gen_g.PushBack({gen_g.NewOpInst(CAST), CreateInt32Tensor(slice_id_ * slice_output_size), dtype}); + auto start_offset = gen_g.PushBack({gen_g.NewOpInst(MUL), interval, offset_size}); + auto new_start = gen_g.PushBack({gen_g.NewOpInst(ADD), gen_g.virtual_input_node(), start_offset}); + auto end_start_offset_size = gen_g.PushBack({gen_g.NewOpInst(CAST), CreateInt32Tensor(slice_output_size - 1), dtype}); + auto start_end_offset = gen_g.PushBack({gen_g.NewOpInst(MUL), interval, end_start_offset_size}); + auto new_end = gen_g.PushBack({gen_g.NewOpInst(ADD), new_start, start_end_offset}); + auto lin_space = gen_g.PushBack({gen_g.NewOpInst(LIN_SPACE), new_start, new_end, CreatInt64Imm(slice_output_size)}); + + std::vector<std::pair<AnfNodePtr, int64_t>> inputs_nodes = {std::make_pair(sub, 1), std::make_pair(sub, 2), + std::make_pair(new_start, 1)}; + replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>( + std::make_pair(inputs_nodes, lin_space)); + return SUCCESS; +} + +Status LinSpaceInfo::InferSliceId() { + CheckGlobalDeviceManager(); + int64_t rank = g_device_manager->rank_index_in_stage(); + slice_id_ = rank; + if (repeated_calc_num_ > 1) { + if
(repeated_num_in_dev_matrix_right_) { + slice_id_ /= dev_matrix_shape_.back(); + } else { + slice_id_ %= dev_matrix_shape_.front(); + } + } + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/lin_space_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/lin_space_info.h new file mode 100644 index 00000000000..19a1ac22608 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/lin_space_info.h @@ -0,0 +1,62 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_LIN_SPACE_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_LIN_SPACE_INFO_H_ + +#include +#include +#include + +#include "utils/hash_map.h" +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class LinSpaceInfo : public OperatorInfo { + public: + LinSpaceInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + ~LinSpaceInfo() override = default; + + Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); } + std::vector GenerateOpStrategies(int64_t stage_id) override; + std::shared_ptr GenerateBatchStrategies() override; + ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; + + protected: + Status GetAttrs() override; + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + Status InferForwardCommunication() override { return SUCCESS; } + + private: + Status InferSliceId(); + Status ComputeReplaceGraph(const CNodePtr &cnode); + int64_t GetSplitNum(); + + int64_t slice_id_ = 0; + int64_t output_size_ = 0; +}; +} // namespace parallel +} // namespace mindspore + +#endif diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc index 95992c660f6..71964ee3fa8 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc @@ -2112,5 +2112,29 @@ AnfNodePtr CreateTensorTupleAnfNodePtr(const tensor::TensorPtrList &tensor_tuple auto tensor_node = NewValueNode(tensor_ptr); return 
tensor_node->cast(); } + +ForwardOp CreateReduceMeanForwardOp(const std::vector &forward_group, const TypePtr &dtype) { + // Create AllReduceSum op + Operator op0 = CreateAllReduceOp(REDUCE_OP_SUM, forward_group[0].name()); + std::string group_name = forward_group[0].name(); + MS_LOG(INFO) << "The group of forward all reduce is " << group_name; + + // Create RealDiv op + OperatorName operator1_name = REAL_DIV; + std::vector device_list = forward_group[0].GetDevicesList(); + auto divisor = static_cast(device_list.size()); + mindspore::tensor::TensorPtr tensor_ptr = std::make_shared(divisor, dtype); + ValuePtr op1_param_value = MakeValue(tensor_ptr); + Attr op1_param = std::make_pair("divisor", op1_param_value); + OperatorParams operator1_params = {std::make_pair(op1_param, 2)}; + OperatorAttrs operator1_attrs; + OperatorArgs operator1_args = std::make_pair(operator1_attrs, operator1_params); + Operator op1 = std::make_pair(operator1_name, operator1_args); + ForwardOp forward_op = {op0, op1}; + + std::string dtype_name = dtype->ToString(); + MS_LOG(INFO) << "The divisor of Div op is " << device_list.size() << ", the dtype is " << dtype_name; + return forward_op; +} } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h index 3313f570a39..ff5c1f91ed9 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h @@ -112,7 +112,7 @@ class OperatorInfo { // In the inference phase, the memory cost is incurred only when the operator is critical. 
The size is calculated // by the output Status CalculateMemoryCostForInference(); - int64_t ComputeOpAndPrevEdgeParameterInvolved(); + virtual int64_t ComputeOpAndPrevEdgeParameterInvolved(); ForwardOp forward_op() const { return forward_op_; } ForwardOp replace_op() const { return replace_op_; } @@ -227,7 +227,7 @@ class OperatorInfo { void SetRepeatedCalcDevMatrix(); void ResetTensorMapIfRepeatedCalc(); Status CreateGroupByDim(size_t axis, std::vector *group); - Status InferAttrs(); + virtual Status InferAttrs(); void ResetQueueMember(); Status InitWithAutoRepeatCalc(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy); Status InitForCostModelWithAutoRepeatCalc(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy); @@ -374,6 +374,8 @@ ValuePtr MakeListValue(const std::vector &v); ValuePtr MakeTupleListValue(const Shapes &v); AnfNodePtr CreateValueTupleAnfNodePtr(const std::vector &value_tuple); AnfNodePtr CreateTensorTupleAnfNodePtr(const tensor::TensorPtrList &tensor_tuple); + +ForwardOp CreateReduceMeanForwardOp(const std::vector &forward_group, const TypePtr &dtype); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h index 1f0c8ab28cc..42ead8c15e3 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h @@ -70,5 +70,8 @@ #include "frontend/parallel/ops_info/addn_info.h" #include "frontend/parallel/ops_info/inplace_add_info.h" #include "frontend/parallel/ops_info/cdist_info.h" +#include "frontend/parallel/ops_info/gamma_info.h" +#include "frontend/parallel/ops_info/kldiv_loss_info.h" +#include "frontend/parallel/ops_info/lin_space_info.h" #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h 
b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h index f7e3205c343..04fc4acff89 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h @@ -260,6 +260,9 @@ constexpr char POOLED_WIDTH[] = "pooled_width"; constexpr char SPATIAL_SCALE[] = "spatial_scale"; constexpr char SAMPLE_NUM[] = "sample_num"; constexpr char ROI_END_MODE[] = "roi_end_mode"; +constexpr char REDUCTION[] = "reduction"; +constexpr char MEAN[] = "mean"; +constexpr char ATTR_NONE[] = "none"; // Operator constexpr char VIRTUAL_DIV[] = "_VirtualDiv"; @@ -477,6 +480,12 @@ constexpr char INPLACE_SUB[] = "InplaceSub"; constexpr char CDIST[] = "Cdist"; constexpr char L2_LOSS[] = "L2Loss"; constexpr char LERP[] = "Lerp"; +constexpr char SQUARED_DIFFERENCE[] = "SquaredDifference"; +constexpr char ERFINV[] = "Erfinv"; +constexpr char SPLITV[] = "SplitV"; +constexpr char GAMMA[] = "Gamma"; +constexpr char KLDIV_LOSS[] = "KLDivLoss"; +constexpr char LIN_SPACE[] = "LinSpace"; // pipeline constexpr size_t PIPELINE_FUSTION_OFFSET = 100; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc index d0f9148c2f4..97dd0f7afe9 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc @@ -201,30 +201,6 @@ Status ReduceMethod::InferForwardCommunication() { return SUCCESS; } -ForwardOp CreateReduceMeanForwardOp(const std::vector &forward_group, const TypePtr &dtype) { - // Create AllReduceSum op - Operator op0 = CreateAllReduceOp(REDUCE_OP_SUM, forward_group[0].name()); - std::string group_name = forward_group[0].name(); - MS_LOG(INFO) << "The group of forward all reduce is " << group_name; - - // Create RealDiv op - OperatorName operator1_name = REAL_DIV; - std::vector device_list = forward_group[0].GetDevicesList(); - auto divisor = static_cast(device_list.size()); 
- mindspore::tensor::TensorPtr tensor_ptr = std::make_shared(divisor, dtype); - ValuePtr op1_param_value = MakeValue(tensor_ptr); - Attr op1_param = std::make_pair("divisor", op1_param_value); - OperatorParams operator1_params = {std::make_pair(op1_param, 2)}; - OperatorAttrs operator1_attrs; - OperatorArgs operator1_args = std::make_pair(operator1_attrs, operator1_params); - Operator op1 = std::make_pair(operator1_name, operator1_args); - ForwardOp forward_op = {op0, op1}; - - std::string dtype_name = dtype->ToString(); - MS_LOG(INFO) << "The divisor of Div op is " << device_list.size() << ", the dtype is " << dtype_name; - return forward_op; -} - Status ReduceMeanInfo::InferForwardCommunication() { auto strategies = strategy_->GetInputDim(); Dimensions stra = strategies.at(0); @@ -801,7 +777,7 @@ Status SquareSumAllInfo::InferGroup() { // if repeated calculation and the repeated_calc_num_ insert to the first dimension of dev matrix, // it need to handle the first dimension of map. if ((dev_matrix_shape_.size() > size) && !repeated_num_in_dev_matrix_right_) { - group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1))); + group_creat_map.push_back(SizeToLong(dev_matrix_shape_.size() - size_t(1))); } for (size_t index = 0; index < size; ++index) { diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/split_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/split_info.cc index 0cf99d2f321..ee1e7a91f55 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/split_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/split_info.cc @@ -30,9 +30,9 @@ namespace mindspore { namespace parallel { Status SplitInfo::GetAttrs() { int64_t axis = 0; - int64_t output_num = 0; - auto axis_iter = attrs_.find(AXIS); + std::string attr_axis_name = GetSplitAxisAttrName(); + auto axis_iter = attrs_.find(attr_axis_name); if (axis_iter != attrs_.end()) { MS_EXCEPTION_IF_NULL(axis_iter->second); if (axis_iter->second->isa()) { @@ -56,21 +56,6 @@ Status 
SplitInfo::GetAttrs() { } axis_ = LongToSize(axis); - auto output_num_iter = attrs_.find(OUTPUT_NUM); - if (output_num_iter != attrs_.end()) { - MS_EXCEPTION_IF_NULL(output_num_iter->second); - if (output_num_iter->second->isa()) { - output_num = output_num_iter->second->cast()->value(); - } else { - MS_LOG(ERROR) << name_ << ": The value of output_num is not int"; - return FAILED; - } - } else { - MS_LOG(ERROR) << name_ << ": Can not find the output_num attr"; - return FAILED; - } - output_num_ = LongToSize(output_num); - return SUCCESS; } @@ -152,6 +137,9 @@ std::vector SplitInfo::GenerateOpStrategies(int64_t stage_id) { if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_input, &sp_vector) != SUCCESS) { MS_LOG(EXCEPTION) << name_ << ": Generate strategies failed"; } + if (sp_vector.empty()) { + MS_LOG(EXCEPTION) << name_ << ": No available strategy"; + } return sp_vector; } diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/split_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/split_info.h index 3bb21eba95e..58f734d4294 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/split_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/split_info.h @@ -46,10 +46,21 @@ class SplitInfo : public OperatorInfo { Status InferDevMatrixShape() override; Status InferTensorMap() override; Status InferAsLossDivisor() override; + virtual std::string GetSplitAxisAttrName() const { return AXIS; } private: size_t axis_ = 0; - size_t output_num_ = 0; +}; + +class SplitVInfo : public SplitInfo { + public: + SplitVInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : SplitInfo(operator_name, inputs_shape, outputs_shape, attrs) {} + ~SplitVInfo() override = default; + + protected: + std::string GetSplitAxisAttrName() const override { return SPLIT_DIM; }; }; } // namespace parallel } // namespace mindspore diff --git 
a/mindspore/ccsrc/frontend/parallel/ops_info/uniform_real_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/uniform_real_info.cc index 175788ce83c..9b7f2dafd1a 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/uniform_real_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/uniform_real_info.cc @@ -28,6 +28,19 @@ namespace mindspore { namespace parallel { +Status UniformRealInfo::InferAttrs() { + if (infer_attrs_completed_) { + return SUCCESS; + } + + if (GetAttrs() != SUCCESS) { + return FAILED; + } + ResetInputsShape(); + infer_attrs_completed_ = true; + return SUCCESS; +} + Status UniformRealInfo::GetAttrs() { seed_ = GetIntAttr(SEED); if (seed_ < 0) { @@ -39,10 +52,6 @@ Status UniformRealInfo::GetAttrs() { MS_LOG(ERROR) << name_ << ": Seed2 must be greater or equal to zero, bug got " << seed2_; return FAILED; } - ValueTuplePtr shape_value = input_value_[0]->cast(); - MS_EXCEPTION_IF_NULL(shape_value); - inputs_shape_.push_back(GetValue(shape_value)); - return SUCCESS; } @@ -97,7 +106,10 @@ std::vector UniformRealInfo::GenerateOpStrategies(int64_t stage_id) std::vector sp_vector; if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { - MS_LOG(EXCEPTION) << name_ << " : Generate strategies for independent inputs() failed."; + MS_LOG(EXCEPTION) << name_ << ": Generate strategies for independent inputs() failed."; + } + if (sp_vector.empty()) { + MS_LOG(EXCEPTION) << name_ << ": No available strategy."; } return sp_vector; } @@ -120,23 +132,29 @@ void UniformRealInfo::UpdateShape(const CNodePtr &cnode) { } void UniformRealInfo::ReplaceNodeInputOrAttrs() { + // Replace input 'shape' to slice shape auto cnode = cnode_; - int64_t split_num = 1; UpdateShape(cnode); - std::vector stra = strategy_->GetInputDim(); - for (size_t i = 0; i < stra[0].size(); i++) { - split_num *= stra[0][i]; - } - int64_t device_num = stage_device_size_; - int64_t rank_id = g_device_manager->rank_index_in_stage(); - if 
(device_num != split_num) { - auto split_group_num = device_num / split_num; - int64_t seed_bias = rank_id / split_group_num; - auto prim = GetValueNode(cnode->input(0)); - prim->set_attr(SEED, MakeValue(seed_ + seed_bias)); - prim->set_attr(SEED2, MakeValue(seed2_ + seed_bias)); + // Update seed according rank_id + int64_t rank_id = g_device_manager->rank_index_in_stage(); + int64_t seed_bias; + if (repeated_num_in_dev_matrix_right_) { + seed_bias = rank_id / repeated_calc_num_; + } else { + int64_t device_num = stage_device_size_; + seed_bias = rank_id % (device_num / repeated_calc_num_); } + + auto prim = GetValueNode(cnode->input(0)); + prim->set_attr(SEED, MakeValue(seed_ + seed_bias)); + prim->set_attr(SEED2, MakeValue(seed2_ + seed_bias)); +} + +void UniformRealInfo::ResetInputsShape() { + ValueTuplePtr shape_value = input_value_[0]->cast(); + MS_EXCEPTION_IF_NULL(shape_value); + inputs_shape_.push_back(GetValue(shape_value)); } } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/uniform_real_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/uniform_real_info.h index 75c5d73e2f7..f1a268dca85 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/uniform_real_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/uniform_real_info.h @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNIFORMREAL_INFO_H_ -#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNIFORMREAL_INFO_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNIFORM_REAL_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNIFORM_REAL_INFO_H_ #include #include @@ -33,7 +33,7 @@ class UniformRealInfo : public OperatorInfo { public: UniformRealInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} ~UniformRealInfo() override = default; std::vector GenerateOpStrategies(int64_t) override; @@ -42,6 +42,7 @@ class UniformRealInfo : public OperatorInfo { void ReplaceNodeInputOrAttrs() override; protected: + Status InferAttrs() override; Status GetAttrs() override; Status CheckStrategy(const StrategyPtr &strategy) override; Status InferForwardCommunication() override { return SUCCESS; } @@ -49,11 +50,12 @@ class UniformRealInfo : public OperatorInfo { Status InferTensorMap() override; private: + void ResetInputsShape(); + int64_t seed_ = 0; int64_t seed2_ = 0; }; - } // namespace parallel } // namespace mindspore -#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNIFORMREAL_INFO_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNIFORM_REAL_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc index e293d48b0a4..6ebf350b41b 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc @@ -71,8 +71,9 @@ Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) { max_size_strategy_dim_ = i; } } - if (std::find(stra[max_size_strategy_dim_].begin(), 
stra[max_size_strategy_dim_].end(), shard_num_) == - stra[max_size_strategy_dim_].end()) { + if (!stra[max_size_strategy_dim_].empty() && + std::find(stra[max_size_strategy_dim_].begin(), stra[max_size_strategy_dim_].end(), shard_num_) == + stra[max_size_strategy_dim_].end()) { MS_LOG(ERROR) << name_ << ": For each dataset input, the shard strategy can be not shard, " "or shard in one dim with the same shard size between each input." @@ -96,7 +97,7 @@ Status VirtualDatasetInfo::InferForwardCommunication() { return SUCCESS; } Status VirtualDatasetInfo::InferTensorMap() { auto dev_mat_origin = strategy_->GetInputDim()[max_size_strategy_dim_]; auto slice_dim_iter = std::find(dev_mat_origin.begin(), dev_mat_origin.end(), shard_num_); - if (slice_dim_iter == dev_mat_origin.end()) { + if (!dev_mat_origin.empty() && slice_dim_iter == dev_mat_origin.end()) { MS_LOG(ERROR) << name_ << ": The dataset shard strategy only support shard in one dim."; return FAILED; } @@ -108,8 +109,7 @@ Status VirtualDatasetInfo::InferTensorMap() { if (dim == 1) { tensor_map_index.push_back(MAP_NONE); } else if (dim == shard_num_) { - if (repeated_num_in_dev_matrix_right_ && dev_matrix_shape_.size() != dev_mat_origin.size() && - is_auto_parallel_) { + if (repeated_calc_num_ > 1 && repeated_num_in_dev_matrix_right_ && is_auto_parallel_) { tensor_map_index.push_back(dev_mat_origin.size() - slice_dim); } else { tensor_map_index.push_back(dev_mat_origin.size() - 1 - slice_dim); @@ -176,8 +176,10 @@ Status VirtualDatasetInfo::GenerateStrategies(int64_t stage_id) { } for (auto &shape : inputs_shape_) { Shape temp; - temp.emplace_back(SizeToLong(total_dev_num)); - (void)temp.insert(temp.end(), shape.size() - 1, 1); + if (!shape.empty()) { + temp.emplace_back(SizeToLong(total_dev_num)); + (void)temp.insert(temp.end(), shape.size() - 1, 1); + } strategy.push_back(temp); } } diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc 
index 85328e00881..a22a5582a6f 100644 --- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc @@ -179,7 +179,7 @@ bool IsSplittableOperator(const std::string &op_name) { RESIZE_NEAREST_NEIGHBOR, CUM_SUM, FAST_GELU, IOU, BOUNDING_BOX_ENCODE, RANDOM_CHOICE_WITH_MASK, CROP_AND_RESIZE, ROI_ALIGN, IS_FINITE, RINT, HSHRINK, HSIGMOID, MISH, SELU, SOFT_SHRINK, XLOGY, XDIVY, CUM_PROD, BITWISE_AND, BITWISE_OR, BITWISE_XOR, MUL_NO_NAN, TRUNCATE_DIV, TRUNCATE_MOD, INPLACE_ADD, INPLACE_SUB, L2_LOSS, LERP, ADDN, - CDIST}; + CDIST, SQUARED_DIFFERENCE, ERFINV, MASKED_FILL, SPLITV, GAMMA, KLDIV_LOSS, LIN_SPACE}; // clang-format on auto iter = splittable_op.find(op_name); diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index 5c0ab24665f..294819e48d5 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -1809,18 +1809,19 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) { std::vector elements; for (size_t i = 0; i < shape_list[0].size(); i++) { if (shape_list[0][i].empty()) { - MS_LOG(EXCEPTION) << "shape_list[ " << i << " ].size() is zero"; + (void)elements.emplace_back(MakeValue(Dimensions())); + continue; } Dimensions input_strategy; - if (!shape_list[0][i].empty() && shape_list[0][i][0] % dev_num == 0) { + if (shape_list[0][i][0] % dev_num == 0) { input_strategy.push_back(dev_num); - } else if (!shape_list[0][i].empty()) { + } else { input_strategy.push_back(1); } for (size_t j = 1; j < shape_list[0][i].size(); j++) { input_strategy.push_back(1); } - elements.push_back(MakeValue(input_strategy)); + (void)elements.emplace_back(MakeValue(input_strategy)); } ValueTuplePtr strategy = std::make_shared(elements); attrs_temp[IN_STRATEGY] = strategy; diff --git a/tests/ut/python/parallel/test_arithmetic.py b/tests/ut/python/parallel/test_arithmetic.py index 
46e228555a5..7d624fca9c2 100644 --- a/tests/ut/python/parallel/test_arithmetic.py +++ b/tests/ut/python/parallel/test_arithmetic.py @@ -1110,3 +1110,89 @@ def test_matmul_xlogy_broadcast(): y = Tensor(np.ones([32, 1]), dtype=ms.float32) b = Tensor(np.ones([1, 64]), dtype=ms.float32) compile_net(net, x, y, b) + + +def test_matmul_squared_difference_broadcast(): + """ + Feature: distribute operator SquaredDifference in auto parallel. + Description: mul-SquaredDifference net with strategy in semi auto parallel. + Expectation: compile done without error. + """ + class Net(nn.Cell): + def __init__(self, strategy1, strategy2): + super().__init__() + self.matmul = P.MatMul().shard(strategy1) + self.squared_difference = P.SquaredDifference().shard(strategy2) + + def construct(self, x, y, b): + out = self.matmul(x, y) + out = self.squared_difference(out, b) + return out + + context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") + strategy1 = ((2, 4), (4, 1)) + strategy2 = ((4, 1), (1, 2)) + net = GradWrap(NetWithLoss(Net(strategy1, strategy2))) + + x = Tensor(np.ones([64, 32]), dtype=ms.float32) + y = Tensor(np.ones([32, 1]), dtype=ms.float32) + b = Tensor(np.ones([1, 64]), dtype=ms.float32) + compile_net(net, x, y, b) + + +def test_matmul_masked_fill_broadcast_with_value_float(): + """ + Feature: distribute operator MaskedFill in auto parallel. + Description: mul-MaskedFill net with strategy in semi auto parallel. + Expectation: compile done without error. 
+ """ + class Net(nn.Cell): + def __init__(self, strategy1, strategy2): + super().__init__() + self.matmul = P.MatMul().shard(strategy1) + self.masked_fill = P.MaskedFill().shard(strategy2) + self.value = 1.0 + + def construct(self, x, y, b): + out = self.matmul(x, y) + out = self.masked_fill(out, b, self.value) + return out + + context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") + strategy1 = ((2, 4), (4, 1)) + strategy2 = ((4, 1), (1, 2)) + net = Net(strategy1, strategy2) + + x = Tensor(np.ones([64, 32]), dtype=ms.float32) + y = Tensor(np.ones([32, 1]), dtype=ms.float32) + b = Tensor(np.ones([1, 64]), dtype=ms.bool_) + compile_net(net, x, y, b) + + +def test_matmul_masked_fill_broadcast_with_value_tensor(): + """ + Feature: distribute operator MaskedFill in auto parallel. + Description: mul-MaskedFill net with strategy in semi auto parallel. + Expectation: compile done without error. + """ + class Net(nn.Cell): + def __init__(self, strategy1, strategy2): + super().__init__() + self.matmul = P.MatMul().shard(strategy1) + self.masked_fill = P.MaskedFill().shard(strategy2) + self.value = Tensor(1.0, ms.float32) + + def construct(self, x, y, b): + out = self.matmul(x, y) + out = self.masked_fill(out, b, self.value) + return out + + context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") + strategy1 = ((2, 4), (4, 1)) + strategy2 = ((4, 1), (1, 2), ()) + net = Net(strategy1, strategy2) + + x = Tensor(np.ones([64, 32]), dtype=ms.float32) + y = Tensor(np.ones([32, 1]), dtype=ms.float32) + b = Tensor(np.ones([1, 64]), dtype=ms.bool_) + compile_net(net, x, y, b) diff --git a/tests/ut/python/parallel/test_element_wise_function.py b/tests/ut/python/parallel/test_element_wise_function.py index 7c5840300e9..d0d2a378657 100644 --- a/tests/ut/python/parallel/test_element_wise_function.py +++ b/tests/ut/python/parallel/test_element_wise_function.py @@ -1173,3 +1173,34 @@ def 
test_matmul_soft_shrink(): y = Tensor(np.random.uniform(-5, 5, size=(32, 64)), dtype=ms.float32) b = Tensor(np.random.uniform(-5, 5, size=(64, 64)), dtype=ms.float32) compile_net(net, x, y, b) + + +def test_matmul_erfinv(): + """ + Feature: distribute operator Erfinv in auto parallel. + Description: matmul-squared_difference-matmul net with strategy in semi auto parallel. + Expectation: compile done without error. + """ + class Net(nn.Cell): + def __init__(self, strategy1, strategy2): + super().__init__() + self.matmul = P.MatMul().shard(strategy1) + self.erfinv = P.Erfinv().shard(strategy2) + self.matmul2 = P.MatMul().shard(strategy1) + + def construct(self, x, y, b): + out = self.matmul(x, y) + out = self.erfinv(out) + out = self.matmul2(out, b) + return out + + context.set_auto_parallel_context(device_num=8, global_rank=0) + strategy1 = ((2, 2), (2, 2)) + strategy2 = ((4, 2),) + net = GradWrap(NetWithLoss(Net(strategy1, strategy2))) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") + + x = Tensor(np.random.uniform(-5, 5, size=(128, 32)), dtype=ms.float32) + y = Tensor(np.random.uniform(-5, 5, size=(32, 64)), dtype=ms.float32) + b = Tensor(np.random.uniform(-5, 5, size=(64, 64)), dtype=ms.float32) + compile_net(net, x, y, b) diff --git a/tests/ut/python/parallel/test_gamma.py b/tests/ut/python/parallel/test_gamma.py new file mode 100644 index 00000000000..d2c6cec4741 --- /dev/null +++ b/tests/ut/python/parallel/test_gamma.py @@ -0,0 +1,94 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +import mindspore.common.dtype as mstype +from mindspore import Tensor, context +from mindspore.nn import Cell +from mindspore.ops import operations as P + +from parallel.utils.utils import ParallelValidator, compile_net + +SEED_ = 1 +SEED2_ = 1 +alpha_ = Tensor(np.array([1.0]), mstype.float32) +beta_ = Tensor(np.array([1.0]), mstype.float32) + + +class Net(Cell): + def __init__(self, seed, seed2, strategy=None): + super(Net, self).__init__() + self.gamma = P.Gamma(seed, seed2).shard(strategy) + + def construct(self, shape, alpha, beta): + out = self.gamma(shape, alpha, beta) + return out + + +def test_gamma_auto_parallel(): + """ + Features: test Gamma auto parallel + Description: auto parallel + Expectation: compile success + """ + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0, full_batch=True) + net = Net(SEED_, SEED2_) + shape = (4, 4, 4) + compile_net(net, shape, alpha_, beta_) + + +def test_gamma_data_parallel(): + """ + Features: test Gamma data parallel + Description: data parallel + Expectation: compile success + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=1) + net = Net(SEED_, SEED2_) + shape = (8, 8) + phase = compile_net(net, shape, alpha_, beta_) + + validator = ParallelValidator(net, phase) + assert validator.check_node_attrs("Gamma-0", {"seed": 2, "seed2": 2}) + + +def test_gamma_model_parallel(): + """ + Features: test Gamma model parallel + Description: model parallel + Expectation: compile success + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=5) + shape = (8, 8) + strategy = ((2, 2), (1,), (1,)) + net = Net(SEED_, SEED2_, strategy) + phase = compile_net(net, shape, alpha_, beta_) + validator = ParallelValidator(net, phase) + assert 
validator.check_node_attrs("Gamma-0", {"seed": 3, "seed2": 3}) + + +def test_gamma_strategy_error(): + """ + Features: test Gamma strategy error + Description: invalid strategy + Expectation: Raise RuntimeError + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + shape = (8, 8) + strategy = ((2, 2), (2,), (1,)) + net = Net(SEED_, SEED2_, strategy) + with pytest.raises(RuntimeError): + compile_net(net, shape, alpha_, beta_) diff --git a/tests/ut/python/parallel/test_kldiv_loss.py b/tests/ut/python/parallel/test_kldiv_loss.py new file mode 100644 index 00000000000..42ef7b08bcb --- /dev/null +++ b/tests/ut/python/parallel/test_kldiv_loss.py @@ -0,0 +1,127 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np + +import mindspore.common.dtype as mstype +from mindspore import Tensor, context +from mindspore.nn import Cell +from mindspore.ops import operations as P + +from parallel.utils.utils import ParallelValidator, compile_net + +logits_ = Tensor(np.random.uniform(0, 1, [8, 8]), mstype.float32) +labels_ = Tensor(np.random.randint(0, 10, [8, 8]), mstype.float32) + + +class Net(Cell): + def __init__(self, reduction, strategy=None): + super(Net, self).__init__() + self.kldiv_loss = P.KLDivLoss(reduction).shard(strategy) + + def construct(self, logits, labels): + out = self.kldiv_loss(logits, labels) + return out + + +def test_kldiv_loss_mean_auto_parallel(): + """ + Features: test KLDivLoss auto parallel + Description: auto parallel, reduction is 'mean' + Expectation: compile success + """ + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0, full_batch=True) + reduction = 'mean' + net = Net(reduction) + compile_net(net, logits_, labels_) + + +def test_kldiv_loss_none_auto_parallel(): + """ + Features: test KLDivLoss auto parallel + Description: auto parallel, reduction is 'none' + Expectation: compile success + """ + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0, full_batch=True) + reduction = 'none' + net = Net(reduction) + compile_net(net, logits_, labels_) + + +def test_kldiv_loss_sum_auto_parallel(): + """ + Features: test KLDivLoss auto parallel + Description: auto parallel, reduction is 'sum' + Expectation: compile success + """ + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0, full_batch=True) + reduction = 'sum' + net = Net(reduction) + compile_net(net, logits_, labels_) + + +def test_kldiv_loss_mean_data_parallel(): + """ + Features: test KLDivLoss data parallel + Description: data parallel, reduction is 'mean' + Expectation: compile success + """ + 
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=1) + reduction = 'mean' + net = Net(reduction) + phase = compile_net(net, logits_, labels_) + validator = ParallelValidator(net, phase) + assert validator.check_node_inputs('AllReduce-0', ['KLDivLoss-0']) + assert validator.check_node_attrs('AllReduce-0', {'op': 'sum'}) + + +def test_kldiv_loss_none_data_parallel(): + """ + Features: test KLDivLoss data parallel + Description: data parallel, reduction is 'none' + Expectation: compile success + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=1) + reduction = 'none' + net = Net(reduction) + compile_net(net, logits_, labels_) + + +def test_kldiv_loss_none_model_parallel(): + """ + Features: test KLDivLoss model parallel + Description: model parallel, reduction is 'none' + Expectation: compile success + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=5) + reduction = 'none' + strategy = ((2, 2), (2, 2)) + net = Net(reduction, strategy) + compile_net(net, logits_, labels_) + + +def test_kldiv_loss_mean_model_parallel(): + """ + Features: test KLDivLoss model parallel + Description: model parallel, reduction is 'mean' + Expectation: compile success + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=5) + reduction = 'mean' + strategy = ((4, 2), (4, 2)) + net = Net(reduction, strategy) + phase = compile_net(net, logits_, labels_) + validator = ParallelValidator(net, phase) + assert validator.check_node_inputs('AllReduce-0', ['KLDivLoss-0']) + assert validator.check_node_attrs('AllReduce-0', {'op': 'sum'}) diff --git a/tests/ut/python/parallel/test_l2_loss.py b/tests/ut/python/parallel/test_l2_loss.py index 0aa2a05434b..fcc54b75fad 100644 --- a/tests/ut/python/parallel/test_l2_loss.py +++ b/tests/ut/python/parallel/test_l2_loss.py @@ -18,7 +18,7 @@ from mindspore import 
Tensor, context from mindspore.nn import Cell from mindspore.ops import operations as P -from parallel.utils.utils import compile_net +from parallel.utils.utils import ParallelValidator, compile_net x_ = Tensor(np.random.normal(size=[32, 8, 8]).astype(np.float32)) @@ -52,4 +52,21 @@ def test_l2_loss_model_parallel(): context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) strategy = ((2, 2, 2),) net = Net(strategy) - compile_net(net, x_) + phase = compile_net(net, x_) + validator = ParallelValidator(net, phase) + assert validator.check_node_inputs('AllReduce-0', ['L2Loss-0']) + assert validator.check_node_attrs('AllReduce-0', {'op': 'sum'}) + + +def test_l2_loss_data_parallel(): + """ + Feature: test L2Loss data parallel + Description: data parallel + Expectation: compile success + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + net = Net() + phase = compile_net(net, x_) + validator = ParallelValidator(net, phase) + assert validator.check_node_inputs('AllReduce-0', ['L2Loss-0']) + assert validator.check_node_attrs('AllReduce-0', {'op': 'sum'}) diff --git a/tests/ut/python/parallel/test_lin_space.py b/tests/ut/python/parallel/test_lin_space.py new file mode 100644 index 00000000000..bfe8703cd8d --- /dev/null +++ b/tests/ut/python/parallel/test_lin_space.py @@ -0,0 +1,111 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +import mindspore.common.dtype as mstype +from mindspore import Tensor, context +from mindspore.nn import Cell +from mindspore.ops import operations as P + +from parallel.utils.utils import ParallelValidator, compile_net + + +class Net(Cell): + def __init__(self, strategy=None): + super(Net, self).__init__() + self.lin_space = P.LinSpace().shard(strategy) + + def construct(self, start, end, x): + return self.lin_space(start, end, x) + + +def test_lin_space_loss_auto_parallel(): + """ + Feature: test LinSpace auto parallel + Description: auto parallel + Expectation: compile success + """ + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0) + start = Tensor(1, mstype.float32) + end = Tensor(10, mstype.float32) + x = 8 + net = Net() + compile_net(net, start, end, x) + + +def test_lin_space_parallel_with_repeated_cal(): + """ + Feature: test LinSpace parallel + Description: parallel with repeated_cal + Expectation: compile success + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + start = Tensor(1, mstype.float32) + end = Tensor(10, mstype.float32) + x = 8 + strategy = ((4,),) + net = Net(strategy) + phase = compile_net(net, start, end, x) + validator = ParallelValidator(net, phase) + assert validator.check_node_inputs('LinSpace-0', ['Add-0', 'Add-1', 2]) + + +def test_lin_space_parallel_with_x_2(): + """ + Feature: test LinSpace parallel + Description: parallel with input x is 2 + Expectation: compile success + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + start = Tensor(1, mstype.float32) + end = Tensor(10, mstype.float32) + x = 2 + strategy = ((2,),) + net = Net(strategy) + phase = compile_net(net, start, end, x) + validator = ParallelValidator(net, phase) + assert validator.check_node_inputs('LinSpace-0', ['Add-0', 'Add-1', 1]) + + +def test_lin_space_data_parallel(): + """ + Feature: 
test LinSpace data parallel + Description: data parallel + Expectation: compile success + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + start = Tensor(1, mstype.float32) + end = Tensor(10, mstype.float32) + x = 8 + net = Net() + phase = compile_net(net, start, end, x) + validator = ParallelValidator(net, phase) + assert validator.check_node_inputs('LinSpace-0', ['Add-0', 'Add-1', 1]) + + +def test_lin_space_parallel_strategy_wrong(): + """ + Feature: test LinSpace parallel with invalid strategy + Description: invalid strategy + Expectation: raise RuntimeError + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + start = Tensor(1, mstype.float32) + end = Tensor(10, mstype.float32) + x = 6 + strategy = ((4,),) + net = Net(strategy) + with pytest.raises(RuntimeError): + compile_net(net, start, end, x) diff --git a/tests/ut/python/parallel/test_splitv.py b/tests/ut/python/parallel/test_splitv.py new file mode 100644 index 00000000000..7d109ac0707 --- /dev/null +++ b/tests/ut/python/parallel/test_splitv.py @@ -0,0 +1,118 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +import numpy as np +import pytest + +import mindspore as ms +import mindspore.context as context +from mindspore import Tensor, Parameter +import mindspore.nn as nn +from mindspore.ops import operations as P + +from parallel.utils.utils import compile_net + + +input_x_tensor_ = Tensor(np.ones([8, 8, 8]), ms.float32) +input_x_parameter_ = Parameter(Tensor(np.ones([8, 8, 8]), ms.float32), "input_x") +SIZE_SPLIT = [3, 3, 2] +NUM_SPLIT = 3 + + +class Net(nn.Cell): + def __init__(self, size_split, split_dim, num_split, strategy1=None, strategy2=None): + super(Net, self).__init__() + self.splitv = P.SplitV(size_split, split_dim, num_split).shard(strategy1) + self.mul = P.Mul().shard(strategy2) + self.weight = Parameter(np.array([1.0]), name="mul_weight") + + def construct(self, x): + out = self.splitv(x) + out = self.mul(out[0], self.weight) + return out + + +class NetWithParameter(nn.Cell): + def __init__(self, size_split, split_dim, num_split, strategy1=None, strategy2=None): + super(NetWithParameter, self).__init__() + self.splitv = P.SplitV(size_split, split_dim, num_split).shard(strategy1) + self.mul = P.Mul().shard(strategy2) + self.weight = input_x_parameter_ + + def construct(self, x): + out = self.splitv(self.weight) + out = self.mul(x, out[0]) + return out + + +def test_splitv_auto_parallel(): + """ + Feature: test SplitV auto parallel + Description: auto parallel + Expectation: compile success + """ + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0) + net = Net(SIZE_SPLIT, 0, NUM_SPLIT) + compile_net(net, input_x_tensor_) + + +def test_splitv_auto_parallel_with_parameter(): + """ + Feature: test SplitV auto parallel with parameter input + Description: auto parallel with parameter input + Expectation: compile success + """ + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0) + net = 
NetWithParameter(SIZE_SPLIT, 2, NUM_SPLIT) + x = Tensor(np.ones([8, 8, 3]), ms.float32) + compile_net(net, x) + + +def test_splitv_data_parallel(): + """ + Feature: test SplitV data parallel + Description: data parallel + Expectation: compile success + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + net = Net(SIZE_SPLIT, 1, NUM_SPLIT) + compile_net(net, input_x_tensor_) + + +def test_splitv_model_parallel(): + """ + Feature: test SplitV model parallel + Description: model parallel + Expectation: compile success + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((1, 2, 2),) + strategy2 = ((1, 2, 2), (1,)) + net = Net(SIZE_SPLIT, 0, NUM_SPLIT, strategy1, strategy2) + compile_net(net, input_x_tensor_) + + +def test_splitv_strategy_error(): + """ + Feature: test SplitV parallel with invalid strategy + Description: config invalid strategy + Expectation: raise RuntimeError + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((2, 2, 2),) + strategy2 = ((8, 1, 1),) + net = Net(SIZE_SPLIT, 0, NUM_SPLIT, strategy1, strategy2) + with pytest.raises(RuntimeError): + compile_net(net, input_x_tensor_) + context.reset_auto_parallel_context() diff --git a/tests/ut/python/parallel/test_uniform_real.py b/tests/ut/python/parallel/test_uniform_real.py new file mode 100644 index 00000000000..9ce099f0355 --- /dev/null +++ b/tests/ut/python/parallel/test_uniform_real.py @@ -0,0 +1,74 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mindspore import context +from mindspore.nn import Cell +from mindspore.ops import operations as P + +from parallel.utils.utils import ParallelValidator, compile_net + +SEED_ = 1 +SEED2_ = 1 + + +class Net(Cell): + def __init__(self, seed, seed2, strategy=None): + super(Net, self).__init__() + self.uniform_real = P.UniformReal(seed, seed2).shard(strategy) + + def construct(self, shape): + out = self.uniform_real(shape) + return out + + +def test_uniform_real_auto_parallel(): + """ + Features: test UniformReal auto parallel + Description: auto parallel + Expectation: compile success + """ + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0) + net = Net(SEED_, SEED2_) + shape = (4, 4, 4) + compile_net(net, shape) + + +def test_uniform_real_data_parallel(): + """ + Features: test UniformReal data parallel + Description: data parallel + Expectation: compile success + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=1) + net = Net(SEED_, SEED2_) + shape = (8, 8) + phase = compile_net(net, shape) + + validator = ParallelValidator(net, phase) + assert validator.check_node_attrs("UniformReal-0", {"seed": 2, "seed2": 2}) + + +def test_uniform_real_model_parallel(): + """ + Features: test UniformReal model parallel + Description: model parallel + Expectation: compile success + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=5) + shape = (8, 8) + strategy = ((2, 2),) + net = Net(SEED_, SEED2_, strategy) + phase = 
compile_net(net, shape) + validator = ParallelValidator(net, phase) + assert validator.check_node_attrs("UniformReal-0", {"seed": 3, "seed2": 3})