!31988 Implementation of SquaredDifferenceInfo, ErfinvInfo, MaskedFillInfo, SplitVInfo, GammaInfo, KLDivLossInfo and LinSpaceInfo.

Merge pull request !31988 from liuluobin/ops_impl
This commit is contained in:
i-robot 2022-03-29 08:08:32 +00:00 committed by Gitee
commit e897c98b1f
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
35 changed files with 1629 additions and 86 deletions

View File

@ -192,6 +192,8 @@ using IOUCost = CastCost;
using RandomChoicWithMaskCost = CastCost;
using IsFiniteCost = CastCost;
using RintCost = CastCost;
using GammaCost = CastCost;
using LinSpaceCost = CastCost;
class SqrtCost : public CastCost {
public:
@ -247,6 +249,7 @@ using ErfcCost = ReLU6Cost;
using ActivationInfoCost = ReLU6Cost;
using SelectCost = ReLU6Cost;
using XlogyCost = ReLU6Cost;
using ErfinvCost = ReLU6Cost;
class TransposeCost : public CastCost {
public:
@ -634,6 +637,7 @@ using BitwiseOrCost = SubCost;
using BitwiseXorCost = SubCost;
using AddNCost = SubCost;
using InplaceAddCost = SubCost;
using MaskedFillCost = SubCost;
class MulCost : public SubCost {
public:
@ -646,6 +650,7 @@ class MulCost : public SubCost {
using MulNoNanCost = MulCost;
using GatherDCost = MulCost;
using LerpCost = MulCost;
using SquaredDifferenceCost = MulCost;
class DivCost : public SubCost {
public:
@ -777,6 +782,7 @@ using ReduceMethodCost = ReduceSumCost;
using ReduceProdCost = ReduceSumCost;
using SquareSumAllCost = ReduceSumCost;
using L2LossCost = ReduceSumCost;
using KLDivLossCost = ReduceSumCost;
class ReduceMeanCost : public ReduceSumCost {
public:

View File

@ -243,6 +243,13 @@ REGISTER(InplaceSubInfo);
REGISTER(CdistInfo);
REGISTER(L2LossInfo);
REGISTER(LerpInfo);
REGISTER(SquaredDifferenceInfo);
REGISTER(ErfinvInfo);
REGISTER(MaskedFillInfo);
REGISTER(SplitVInfo);
REGISTER(GammaInfo);
REGISTER(KLDivLossInfo);
REGISTER(LinSpaceInfo);
} // namespace parallel
} // namespace mindspore

View File

@ -91,6 +91,19 @@ AnfNodePtr CreateInt32Tensor(int64_t value) {
return anf_node_ptr;
}
static mindspore::HashMap<float, AnfNodePtr> float_tensor_map = {};
AnfNodePtr CreateFP32Tensor(float value) {
auto it = float_tensor_map.find(value);
if (it != float_tensor_map.end()) {
return it->second;
}
mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<tensor::Tensor>(value, kFloat32);
ValuePtr value_ptr = MakeValue(tensor_ptr);
auto anf_node_ptr = ValuePtrToAnfNodePtr(value_ptr);
float_tensor_map[value] = anf_node_ptr;
return anf_node_ptr;
}
AnfNodePtr CreateTypeInt(int64_t nbits) {
ValuePtr value_ptr = MakeValue(std::make_shared<Int>(nbits));
return ValuePtrToAnfNodePtr(value_ptr);

View File

@ -42,6 +42,7 @@ AnfNodePtr CreateTypeFloat(int64_t nbits);
AnfNodePtr CreatInt64Imm(int64_t value);
AnfNodePtr CreateFP32Imm(float value);
AnfNodePtr CreateInt32Tensor(int64_t value);
AnfNodePtr CreateFP32Tensor(float value);
AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr);
AnfNodePtr CreateTuple(const std::vector<int64_t> &tuple);
std::string HashInstanceName(const std::string &name);

View File

@ -76,7 +76,16 @@ std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
} else {
inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
}
return std::vector<bool>(inputs_seq.size(), false);
size_t inputs_seq_tensor_size = inputs_seq.size();
for (const auto &inputs_seq_value : inputs_seq) {
auto tensor = inputs_seq_value->cast<tensor::TensorPtr>();
if (tensor == nullptr) {
MS_LOG(DEBUG) << "The value not is not a tensor.";
inputs_seq_tensor_size = 0;
break;
}
}
return std::vector<bool>(inputs_seq_tensor_size, false);
}
if ((node_inputs.size() == 2) &&
(AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) {
@ -168,7 +177,10 @@ std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
}
for (auto &ele : inputs_seq) {
auto tensor = ele->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(tensor);
if (tensor == nullptr) {
inputs_type_len.clear();
return inputs_type_len;
}
inputs_type_len.push_back(GetLengthOfDataType(tensor->Dtype()));
}
return inputs_type_len;

View File

@ -28,6 +28,7 @@
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/redistribution_operator_infer.h"
#include "frontend/parallel/graph_util/generate_graph.h"
namespace mindspore {
namespace parallel {
@ -616,5 +617,33 @@ Status L2LossInfo::InferTensorMap() {
outputs_tensor_map_[0].clear();
return SUCCESS;
}
Status L2LossInfo::InferForwardCommunication() {
forward_op_.clear();
Shape group_create_map;
if (repeated_calc_num_ > 1) {
if (repeated_num_in_dev_matrix_right_) {
group_create_map.push_back(0);
} else {
group_create_map.push_back(SizeToLong(dev_matrix_shape_.size()) - 1);
}
}
std::vector<Group> group_list;
if (CreateGroupByTensorMap(group_create_map, &group_list) != SUCCESS) {
ReportError(name_ + ": Create group failed.");
return FAILED;
}
if (group_list.empty()) {
MS_LOG(INFO) << name_ << ": Forward all reduce is not required";
return SUCCESS;
}
Operator op = CreateAllReduceOp(REDUCE_OP_SUM, group_list[0].name());
forward_op_.push_back(op);
MS_LOG(INFO) << name_ << ": The group name of forward all reduce is " << group_list[0].name();
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

View File

@ -374,13 +374,22 @@ class SoftShrinkInfo : public ActivationOther {
class L2LossInfo : public ActivationOther {
public:
L2LossInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &output_shape,
L2LossInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, output_shape, attrs, std::make_shared<L2LossCost>()) {}
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<L2LossCost>()) {}
~L2LossInfo() = default;
protected:
Status InferTensorMap() override;
Status InferForwardCommunication() override;
};
class ErfinvInfo : public ActivationOther {
public:
ErfinvInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ErfinvCost>()) {}
~ErfinvInfo() override = default;
};
} // namespace parallel
} // namespace mindspore

View File

@ -353,5 +353,39 @@ void LerpInfo::ReComputeBatchSplitFlagList() {
Shape expand_weight_shape = ExpandShape(expand_a_shape, inputs_shape_.at(2));
(expand_weight_shape.at(0) != 1) ? (split_flag_list_[2] = true) : (split_flag_list_[2] = false);
}
Status MaskedFillInfo::GetAttrs() {
input_size_ = inputs_shape_.size();
if (input_size_ != 2 && input_size_ != 3) {
MS_LOG(ERROR) << name_ << ": inputs_shape_.size() must be 2 or 3, but got size " << input_size_;
return FAILED;
}
return SUCCESS;
}
Status MaskedFillInfo::InferTensorMap() {
if (ArithmeticBase::InferTensorMap() != SUCCESS) {
return FAILED;
}
if (input_size_ == 3) {
// append a void tensor map for 0-dimensional tensor input 'value'
(void)inputs_tensor_map_.emplace_back(TensorMap());
}
return SUCCESS;
}
std::vector<StrategyPtr> MaskedFillInfo::GenerateOpStrategies(int64_t stage_id) {
auto sp_vector = ArithmeticBase::GenerateOpStrategies(stage_id);
if (input_size_ == 3) {
// append void strategy for input `value`
for (size_t i = 0; i < sp_vector.size(); ++i) {
auto strategies = sp_vector[i]->GetInputDim();
(void)strategies.emplace_back(Dimensions());
sp_vector[i] = std::make_shared<Strategy>(stage_id, strategies);
}
}
return sp_vector;
}
} // namespace parallel
} // namespace mindspore

View File

@ -263,6 +263,31 @@ class LerpInfo : public ArithmeticBase {
private:
size_t inputs_size_ = 0;
};
class SquaredDifferenceInfo : public ArithmeticBase {
public:
SquaredDifferenceInfo(const std::string &name, const Shapes &input_shape, const Shapes &output_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, input_shape, output_shape, attrs, std::make_shared<SquaredDifferenceCost>()) {}
~SquaredDifferenceInfo() override = default;
};
class MaskedFillInfo : public ArithmeticBase {
public:
MaskedFillInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &output_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, output_shape, attrs, std::make_shared<MaskedFillCost>()) {}
~MaskedFillInfo() override = default;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
protected:
Status GetAttrs() override;
Status InferTensorMap() override;
private:
size_t input_size_ = 0;
};
} // namespace parallel
} // namespace mindspore

View File

@ -0,0 +1,218 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <map>
#include <utility>
#include "frontend/parallel/ops_info/gamma_info.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "frontend/parallel/auto_parallel/edge_costmodel.h"
namespace mindspore {
namespace parallel {
Status GammaInfo::InferAttrs() {
if (infer_attrs_completed_) {
return SUCCESS;
}
if (GetAttrs() != SUCCESS) {
return FAILED;
}
ResetInputsShape();
infer_attrs_completed_ = true;
return SUCCESS;
}
Status GammaInfo::GetAttrs() {
seed_ = GetIntAttr(SEED);
if (seed_ < 0) {
MS_LOG(ERROR) << name_ << ": Seed must be greater or equal to zero, bug got " << seed_;
return FAILED;
}
seed2_ = GetIntAttr(SEED2);
if (seed2_ < 0) {
MS_LOG(ERROR) << name_ << ": Seed2 must be greater or equal to zero, bug got " << seed2_;
return FAILED;
}
return SUCCESS;
}
Status GammaInfo::CheckStrategy(const StrategyPtr &strategy) {
MS_EXCEPTION_IF_NULL(strategy);
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy";
return FAILED;
}
auto strategies = strategy->GetInputDim();
auto alpha_strategy = strategies.at(1);
auto beta_strategy = strategies.at(2);
if (!IsNotSplittableStrategy(alpha_strategy) || !IsNotSplittableStrategy(beta_strategy)) {
MS_LOG(ERROR) << name_ << ": Cannot shard the input `alpha` and `beta`, but got strategy "
<< StrategyToString(strategies);
return FAILED;
}
return SUCCESS;
}
Status GammaInfo::InferDevMatrixShape() {
dev_matrix_shape_.clear();
auto strategies = strategy_->GetInputDim();
auto shape_strategy = strategies.at(0);
dev_matrix_shape_ = shape_strategy;
MS_LOG(INFO) << name_ << ": The dev matrix is " << ShapeToString(dev_matrix_shape_);
return SUCCESS;
}
Status GammaInfo::InferTensorMap() {
auto strategies = strategy_->GetInputDim();
auto shape_strategy = strategies.at(0);
size_t size = shape_strategy.size();
TensorMap shape_tensor_map;
for (size_t i = 0; i < size; ++i) {
shape_tensor_map.push_back(size - i - 1);
}
inputs_tensor_map_.push_back(shape_tensor_map);
(void)inputs_tensor_map_.emplace_back(TensorMap(strategies.at(1).size(), -1));
(void)inputs_tensor_map_.emplace_back(TensorMap(strategies.at(2).size(), -1));
(void)outputs_tensor_map_.emplace_back(shape_tensor_map);
return SUCCESS;
}
std::vector<StrategyPtr> GammaInfo::GenerateOpStrategies(int64_t stage_id) {
Shape input0_split(inputs_shape_[0].size(), 1);
Shape input1_split(inputs_shape_[1].size(), 0);
Shape input2_split(inputs_shape_[2].size(), 0);
Shapes splittable_inputs = {input0_split, input1_split, input2_split};
std::vector<StrategyPtr> sp_vector;
if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Generate strategies for independent inputs() failed.";
}
if (sp_vector.empty()) {
MS_LOG(EXCEPTION) << name_ << ": No available strategy.";
}
return sp_vector;
}
void GammaInfo::UpdateShape(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
auto input_node = cnode->input(1)->cast<ValueNodePtr>();
std::vector<int64_t> input_shape = GetValue<std::vector<int64_t>>(input_node->value());
std::vector<Dimensions> stra = strategy_->GetInputDim();
for (size_t i = 0; i < stra[0].size(); i++) {
input_shape[i] /= stra[0][i];
}
auto func_graph = cnode->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
ValuePtr new_shape = MakeValue(input_shape);
AnfNodePtr val = NewValueNode(new_shape);
(void)manager->Replace(cnode->input(1), val);
}
void GammaInfo::ReplaceNodeInputOrAttrs() {
// Replace input 'shape' to slice shape
auto cnode = cnode_;
UpdateShape(cnode);
// Update seed according rank_id
int64_t rank_id = g_device_manager->rank_index_in_stage();
int64_t seed_bias;
if (repeated_num_in_dev_matrix_right_) {
seed_bias = rank_id / repeated_calc_num_;
} else {
int64_t device_num = stage_device_size_;
seed_bias = rank_id % (device_num / repeated_calc_num_);
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
prim->set_attr(SEED, MakeValue(seed_ + seed_bias));
prim->set_attr(SEED2, MakeValue(seed2_ + seed_bias));
}
void GammaInfo::ResetInputsShape() {
if (inputs_shape_.size() == input_value_.size()) {
return;
}
ValueTuplePtr shape_value = input_value_[0]->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(shape_value);
(void)inputs_shape_.insert(inputs_shape_.begin(), GetValue<Shape>(shape_value));
}
bool GammaInfo::IsNotSplittableStrategy(const Dimensions &strategy) {
return std::all_of(strategy.begin(), strategy.end(), [](int64_t val) { return val == 1; });
}
void GammaInfo::ReComputeBatchSplitFlagList() {
ResetInputsShape();
split_flag_list_.clear();
split_flag_list_ = std::vector<bool>(inputs_shape_.size(), false);
if (!split_flag_list_.empty()) {
split_flag_list_[0] = true;
}
}
int64_t GammaInfo::ComputeOpAndPrevEdgeParameterInvolved() {
if (is_output_parameter_involve_ != -1) {
return is_output_parameter_involve_;
}
is_parameter_involve_ = is_parameter_;
const auto &prev_edges = this->GetAlivePrevEdges();
for (auto &p_edge : prev_edges) {
auto input_index = p_edge->next_op_input_index();
auto prev_op_para = p_edge->prev_operator()->ComputeOpAndPrevEdgeParameterInvolved();
if (input_index - 1 >= is_parameter_involve_.size()) {
MS_LOG(EXCEPTION) << name_ << " has input length: " << is_parameter_involve_.size()
<< ", but got wrong input_index: " << input_index;
}
if (prev_op_para == 0) {
is_parameter_involve_[input_index - 1] = false;
} else if (prev_op_para == 1) {
is_parameter_involve_[input_index - 1] = true;
} else {
MS_LOG(EXCEPTION) << name_ << " got wrong value: " << prev_op_para << ", input_index: " << input_index;
}
p_edge->set_parameter_involve(prev_op_para);
}
if (std::any_of(is_parameter_involve_.begin(), is_parameter_involve_.end(), [](bool value) { return value; })) {
// If anyone of the input is a parameter_involved, the output is parameter_involved.
is_output_parameter_involve_ = 1;
} else {
is_output_parameter_involve_ = 0;
}
// Set 'is_parameter_involve_' and 'is_output_parameter_involve_' into operatorCost, which are used in
// calculating 'inputs_in_memory' and 'output_in_memory', respectively.
operator_cost()->set_is_parameter_involve(is_parameter_involve_);
operator_cost()->set_output_parameter_involve(is_output_parameter_involve_);
// Calculating 'output_in_memory'
operator_cost()->CalculateOutputInMemory();
// Calculating 'inputs_in_memory'
std::map<size_t, bool> input_in_memory;
for (auto &p_edge : prev_edges) {
auto input_index = p_edge->next_op_input_index();
auto is_in_mem = p_edge->prev_operator()->operator_cost()->is_output_in_memory();
(void)input_in_memory.emplace(std::make_pair(input_index, is_in_mem));
}
operator_cost()->CalculateInputsInMemory(input_in_memory);
return is_output_parameter_involve_;
}
} // namespace parallel
} // namespace mindspore

View File

@ -0,0 +1,64 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GAMMA_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GAMMA_INFO_H_
#include <string>
#include <memory>
#include <vector>
#include "utils/hash_map.h"
#include "ir/value.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"
namespace mindspore {
namespace parallel {
class GammaInfo : public OperatorInfo {
public:
GammaInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GammaCost>()) {}
~GammaInfo() override = default;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); }
void UpdateShape(const CNodePtr &cnode);
void ReplaceNodeInputOrAttrs() override;
void ReComputeBatchSplitFlagList() override;
int64_t ComputeOpAndPrevEdgeParameterInvolved() override;
protected:
Status InferAttrs() override;
Status GetAttrs() override;
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
private:
void ResetInputsShape();
bool IsNotSplittableStrategy(const Dimensions &strategy);
int64_t seed_ = 0;
int64_t seed2_ = 0;
};
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GAMMA_INFO_H_

View File

@ -0,0 +1,136 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <utility>
#include "frontend/parallel/ops_info/kldiv_loss_info.h"
#include "frontend/parallel/graph_util/generate_graph.h"
namespace mindspore {
namespace parallel {
Status KLDivLossInfo::GetAttrs() {
reduction_ = GetStringAttr(REDUCTION);
return SUCCESS;
}
Status KLDivLossInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy.";
return FAILED;
}
auto strategies = strategy->GetInputDim();
if (strategies[0] != strategies[1]) {
MS_LOG(ERROR) << name_ << ": The strategy of 'logits' and 'labels' must be the same, but got strategy "
<< StrategyToString(strategies);
return FAILED;
}
return SUCCESS;
}
Status KLDivLossInfo::InferDevMatrixShape() {
dev_matrix_shape_.clear();
auto strategies = strategy_->GetInputDim();
auto logits_strategy = strategies.at(0);
dev_matrix_shape_ = logits_strategy;
MS_LOG(INFO) << name_ << ": dev matrix is " << ShapeToString(dev_matrix_shape_);
return SUCCESS;
}
Status KLDivLossInfo::InferTensorMap() {
inputs_tensor_map_.clear();
outputs_tensor_map_.clear();
TensorMap sub_tensor_map;
auto strategies = strategy_->GetInputDim();
auto logits_strategy = strategies.at(0);
size_t size = logits_strategy.size();
for (size_t i = 0; i < size; ++i) {
sub_tensor_map.push_back(size - i - 1);
}
inputs_tensor_map_.push_back(sub_tensor_map);
inputs_tensor_map_.push_back(sub_tensor_map);
if (reduction_ == ATTR_NONE) {
(void)outputs_tensor_map_.emplace_back(sub_tensor_map);
} else {
(void)outputs_tensor_map_.emplace_back(TensorMap());
}
return SUCCESS;
}
std::vector<StrategyPtr> KLDivLossInfo::GenerateOpStrategies(int64_t stage_id) {
Shape splittable_input(inputs_shape_[0].size());
std::iota(splittable_input.begin(), splittable_input.end(), 1);
Shapes splittable_inputs(inputs_shape_.size(), splittable_input);
std::vector<StrategyPtr> sp_vector;
if (GenerateStrategiesForDependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Generate strategies for dependent inputs() failed.";
}
if (sp_vector.empty()) {
MS_LOG(EXCEPTION) << name_ << ": No available strategy.";
}
return sp_vector;
}
Status KLDivLossInfo::InferGroup() {
Shape group_create_map;
if (repeated_calc_num_ > 1) {
if (repeated_num_in_dev_matrix_right_) {
group_create_map.push_back(0);
} else {
group_create_map.push_back(SizeToLong(dev_matrix_shape_.size()) - 1);
}
}
if (CreateGroupByTensorMap(group_create_map, &group_list_) != SUCCESS) {
ReportError(name_ + ": Create group failed.");
return FAILED;
}
return SUCCESS;
}
Status KLDivLossInfo::InferForwardCommunication() {
if (reduction_ == ATTR_NONE) {
MS_LOG(DEBUG) << name_ << ": reduction is " << reduction_ << ", there is no need to append reduce op.";
return SUCCESS;
}
if (InferGroup() != SUCCESS) {
MS_LOG(ERROR) << "Infer group failed";
return FAILED;
}
if (group_list_.empty()) {
MS_LOG(INFO) << name_ << ": Forward all reduce is not required";
return SUCCESS;
}
forward_op_.clear();
if (reduction_ == MEAN) {
auto element_type = outputs_dtype_->cast<mindspore::TensorTypePtr>()->element();
forward_op_ = CreateReduceMeanForwardOp(group_list_, element_type);
} else {
(void)forward_op_.emplace_back(CreateAllReduceOp(REDUCE_OP_SUM, group_list_[0].name()));
}
MS_LOG(INFO) << name_ << ": The group name of forward all reduce is " << group_list_[0].name();
return SUCCESS;
}
void KLDivLossInfo::ReComputeBatchSplitFlagList() { split_flag_list_.assign(inputs_shape_.size(), true); }
} // namespace parallel
} // namespace mindspore

View File

@ -0,0 +1,57 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_KLDIV_LOSS_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_KLDIV_LOSS_INFO_H_
#include <memory>
#include <vector>
#include <string>
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
namespace mindspore {
namespace parallel {
class KLDivLossInfo : public OperatorInfo {
public:
KLDivLossInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<KLDivLossCost>()) {}
~KLDivLossInfo() override = default;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); }
void ReComputeBatchSplitFlagList() override;
protected:
Status GetAttrs() override;
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status InferForwardCommunication() override;
private:
Status InferGroup();
std::string reduction_;
std::vector<Group> group_list_;
};
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_KLDIV_LOSS_INFO_H_

View File

@ -0,0 +1,176 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <utility>
#include "frontend/parallel/ops_info/lin_space_info.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "frontend/parallel/graph_util/generate_graph.h"
namespace mindspore {
namespace parallel {
Status LinSpaceInfo::GetAttrs() {
auto output_0_shape = outputs_shape_.at(0);
output_size_ = output_0_shape.at(0);
return SUCCESS;
}
Status LinSpaceInfo::CheckStrategy(const StrategyPtr &strategy) {
auto strategies = strategy->GetInputDim();
if (strategies.size() != 1 && strategies[0].size() != 1) {
MS_LOG(ERROR) << name_ << ": The shape of input_strategy must be [1, 1], but got strategy "
<< StrategyToString(strategies);
return FAILED;
}
auto split_num = strategies[0][0];
if (output_size_ % split_num != 0) {
MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(strategies) << ", output size is "
<< output_size_ << " cannot be divisible by strategy value " << split_num;
}
if (stage_device_size_ % split_num != 0) {
MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(strategies)
<< ", the device size in this stage is " << stage_device_size_
<< " cannot be divisible by the strategy value " << split_num;
return FAILED;
}
return SUCCESS;
}
Status LinSpaceInfo::InferDevMatrixShape() {
dev_matrix_shape_.clear();
auto strategies = strategy_->GetInputDim();
auto split_strategy = strategies.at(0);
dev_matrix_shape_.push_back(split_strategy.at(0));
MS_LOG(INFO) << name_ << ": The device matrix is " << ShapeToString(dev_matrix_shape_);
return SUCCESS;
}
Status LinSpaceInfo::InferTensorMap() {
inputs_tensor_map_.clear();
outputs_tensor_map_.clear();
inputs_tensor_map_.assign(2, TensorMap{});
(void)outputs_tensor_map_.emplace_back(TensorMap{0});
return SUCCESS;
}
std::vector<StrategyPtr> LinSpaceInfo::GenerateOpStrategies(int64_t stage_id) {
Shape split_input(1, 1);
Shape split_shape(1, output_size_);
Shapes splittable_inputs = {split_input};
Shapes splittable_shapes = {split_shape};
std::vector<StrategyPtr> sp_vector;
if (GenerateStrategiesForIndependentInputs(stage_id, splittable_shapes, splittable_inputs, &sp_vector) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Generate strategies for independent inputs() failed.";
}
if (sp_vector.empty()) {
MS_LOG(EXCEPTION) << name_ << ": No available strategy.";
}
return sp_vector;
}
std::shared_ptr<Strategys> LinSpaceInfo::GenerateBatchStrategies() {
if (InferAttrs() != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
}
int64_t dev_num = g_device_manager->stage_device_num();
Strategys strategies = {Dimensions{dev_num}};
return std::make_shared<Strategys>(strategies);
}
int64_t LinSpaceInfo::GetSplitNum() {
auto strategies = strategy_->GetInputDim();
auto split_strategy = strategies.at(0);
auto split_num = split_strategy.at(0);
return split_num;
}
ReplaceGraphPtr LinSpaceInfo::replace_graph(const CNodePtr &cnode) {
auto split_num = GetSplitNum();
if (split_num > 1 && ComputeReplaceGraph(cnode) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
}
return replace_graph_;
}
Status LinSpaceInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
auto split_num = GetSplitNum();
if (output_size_ % split_num != 0) {
return FAILED;
}
if (split_num == 1) {
MS_LOG(INFO) << name_ << ": split num is 1, no need to replace graph";
return SUCCESS;
}
GenerateGraph gen_g = GenerateGraph(attrs_);
if (gen_g.Init(cnode) != SUCCESS) {
MS_LOG(ERROR) << "GenerateGraph Init failed.";
return FAILED;
}
if (InferSliceId() != SUCCESS) {
MS_LOG(ERROR) << "Infer slice id failed.";
return FAILED;
}
int64_t slice_output_size = output_size_ / split_num;
auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), gen_g.virtual_input_node()});
auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), sub});
AnfNodePtr interval = nullptr;
if (output_size_ == 2) {
interval = sub;
} else {
auto interval_divsor = gen_g.PushBack({gen_g.NewOpInst(CAST), CreateInt32Tensor(output_size_ - 2), dtype});
interval = gen_g.PushBack({gen_g.NewOpInst(DIV), sub, interval_divsor});
}
// new_start = start + slice_id * slice_output_size * interval
// new_end = new_start + (slice_output_size - 1) * interval
// new_x = slice_output_size
auto offset_size = gen_g.PushBack({gen_g.NewOpInst(CAST), CreateInt32Tensor(slice_id_ * slice_output_size), dtype});
auto start_offset = gen_g.PushBack({gen_g.NewOpInst(MUL), interval, offset_size});
auto new_start = gen_g.PushBack({gen_g.NewOpInst(ADD), gen_g.virtual_input_node(), start_offset});
auto end_start_offset_size = gen_g.PushBack({gen_g.NewOpInst(CAST), CreateInt32Tensor(slice_output_size - 1), dtype});
auto start_end_offset = gen_g.PushBack({gen_g.NewOpInst(MUL), interval, end_start_offset_size});
auto new_end = gen_g.PushBack({gen_g.NewOpInst(ADD), new_start, start_end_offset});
auto lin_space = gen_g.PushBack({gen_g.NewOpInst(LIN_SPACE), new_start, new_end, CreatInt64Imm(slice_output_size)});
std::vector<std::pair<AnfNodePtr, int64_t>> inputs_nodes = {std::make_pair(sub, 1), std::make_pair(sub, 2),
std::make_pair(new_start, 1)};
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
std::make_pair(inputs_nodes, lin_space));
return SUCCESS;
}
Status LinSpaceInfo::InferSliceId() {
CheckGlobalDeviceManager();
int64_t rank = g_device_manager->rank_index_in_stage();
slice_id_ = rank;
if (repeated_calc_num_ > 1) {
if (repeated_num_in_dev_matrix_right_) {
slice_id_ /= dev_matrix_shape_.back();
} else {
slice_id_ %= dev_matrix_shape_.front();
}
}
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

View File

@ -0,0 +1,62 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_LIN_SPACE_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_LIN_SPACE_INFO_H_
#include <string>
#include <memory>
#include <vector>
#include "utils/hash_map.h"
#include "ir/value.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"
namespace mindspore {
namespace parallel {
class LinSpaceInfo : public OperatorInfo {
public:
LinSpaceInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<LinSpaceCost>()) {}
~LinSpaceInfo() override = default;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); }
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
std::shared_ptr<Strategys> GenerateBatchStrategies() override;
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
protected:
Status GetAttrs() override;
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status InferForwardCommunication() override { return SUCCESS; }
private:
Status InferSliceId();
Status ComputeReplaceGraph(const CNodePtr &cnode);
int64_t GetSplitNum();
int64_t slice_id_ = 0;
int64_t output_size_ = 0;
};
} // namespace parallel
} // namespace mindspore
#endif

View File

@ -2176,5 +2176,29 @@ AnfNodePtr CreateTensorTupleAnfNodePtr(const tensor::TensorPtrList &tensor_tuple
auto tensor_node = NewValueNode(tensor_ptr);
return tensor_node->cast<AnfNodePtr>();
}
ForwardOp CreateReduceMeanForwardOp(const std::vector<Group> &forward_group, const TypePtr &dtype) {
// Create AllReduceSum op
Operator op0 = CreateAllReduceOp(REDUCE_OP_SUM, forward_group[0].name());
std::string group_name = forward_group[0].name();
MS_LOG(INFO) << "The group of forward all reduce is " << group_name;
// Create RealDiv op
OperatorName operator1_name = REAL_DIV;
std::vector<Device> device_list = forward_group[0].GetDevicesList();
auto divisor = static_cast<float>(device_list.size());
mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(divisor, dtype);
ValuePtr op1_param_value = MakeValue(tensor_ptr);
Attr op1_param = std::make_pair("divisor", op1_param_value);
OperatorParams operator1_params = {std::make_pair(op1_param, 2)};
OperatorAttrs operator1_attrs;
OperatorArgs operator1_args = std::make_pair(operator1_attrs, operator1_params);
Operator op1 = std::make_pair(operator1_name, operator1_args);
ForwardOp forward_op = {op0, op1};
std::string dtype_name = dtype->ToString();
MS_LOG(INFO) << "The divisor of Div op is " << device_list.size() << ", the dtype is " << dtype_name;
return forward_op;
}
} // namespace parallel
} // namespace mindspore

View File

@ -112,7 +112,7 @@ class OperatorInfo {
// In the inference phase, the memory cost is incurred only when the operator is critical. The size is calculated
// by the output
Status CalculateMemoryCostForInference();
int64_t ComputeOpAndPrevEdgeParameterInvolved();
virtual int64_t ComputeOpAndPrevEdgeParameterInvolved();
ForwardOp forward_op() const { return forward_op_; }
ForwardOp replace_op() const { return replace_op_; }
@ -227,7 +227,7 @@ class OperatorInfo {
void SetRepeatedCalcDevMatrix();
void ResetTensorMapIfRepeatedCalc();
Status CreateGroupByDim(size_t axis, std::vector<Group> *group);
Status InferAttrs();
virtual Status InferAttrs();
void ResetQueueMember();
Status InitWithAutoRepeatCalc(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy);
Status InitForCostModelWithAutoRepeatCalc(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy);
@ -374,6 +374,8 @@ ValuePtr MakeListValue(const std::vector<int64_t> &v);
ValuePtr MakeTupleListValue(const Shapes &v);
AnfNodePtr CreateValueTupleAnfNodePtr(const std::vector<int64_t> &value_tuple);
AnfNodePtr CreateTensorTupleAnfNodePtr(const tensor::TensorPtrList &tensor_tuple);
ForwardOp CreateReduceMeanForwardOp(const std::vector<Group> &forward_group, const TypePtr &dtype);
} // namespace parallel
} // namespace mindspore

View File

@ -70,5 +70,8 @@
#include "frontend/parallel/ops_info/addn_info.h"
#include "frontend/parallel/ops_info/inplace_add_info.h"
#include "frontend/parallel/ops_info/cdist_info.h"
#include "frontend/parallel/ops_info/gamma_info.h"
#include "frontend/parallel/ops_info/kldiv_loss_info.h"
#include "frontend/parallel/ops_info/lin_space_info.h"
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_

View File

@ -260,6 +260,9 @@ constexpr char POOLED_WIDTH[] = "pooled_width";
constexpr char SPATIAL_SCALE[] = "spatial_scale";
constexpr char SAMPLE_NUM[] = "sample_num";
constexpr char ROI_END_MODE[] = "roi_end_mode";
constexpr char REDUCTION[] = "reduction";
constexpr char MEAN[] = "mean";
constexpr char ATTR_NONE[] = "none";
// Operator
constexpr char VIRTUAL_DIV[] = "_VirtualDiv";
@ -477,6 +480,12 @@ constexpr char INPLACE_SUB[] = "InplaceSub";
constexpr char CDIST[] = "Cdist";
constexpr char L2_LOSS[] = "L2Loss";
constexpr char LERP[] = "Lerp";
constexpr char SQUARED_DIFFERENCE[] = "SquaredDifference";
constexpr char ERFINV[] = "Erfinv";
constexpr char SPLITV[] = "SplitV";
constexpr char GAMMA[] = "Gamma";
constexpr char KLDIV_LOSS[] = "KLDivLoss";
constexpr char LIN_SPACE[] = "LinSpace";
// pipeline
constexpr size_t PIPELINE_FUSTION_OFFSET = 100;

View File

@ -201,30 +201,6 @@ Status ReduceMethod::InferForwardCommunication() {
return SUCCESS;
}
ForwardOp CreateReduceMeanForwardOp(const std::vector<Group> &forward_group, const TypePtr &dtype) {
// Create AllReduceSum op
Operator op0 = CreateAllReduceOp(REDUCE_OP_SUM, forward_group[0].name());
std::string group_name = forward_group[0].name();
MS_LOG(INFO) << "The group of forward all reduce is " << group_name;
// Create RealDiv op
OperatorName operator1_name = REAL_DIV;
std::vector<Device> device_list = forward_group[0].GetDevicesList();
auto divisor = static_cast<float>(device_list.size());
mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(divisor, dtype);
ValuePtr op1_param_value = MakeValue(tensor_ptr);
Attr op1_param = std::make_pair("divisor", op1_param_value);
OperatorParams operator1_params = {std::make_pair(op1_param, 2)};
OperatorAttrs operator1_attrs;
OperatorArgs operator1_args = std::make_pair(operator1_attrs, operator1_params);
Operator op1 = std::make_pair(operator1_name, operator1_args);
ForwardOp forward_op = {op0, op1};
std::string dtype_name = dtype->ToString();
MS_LOG(INFO) << "The divisor of Div op is " << device_list.size() << ", the dtype is " << dtype_name;
return forward_op;
}
Status ReduceMeanInfo::InferForwardCommunication() {
auto strategies = strategy_->GetInputDim();
Dimensions stra = strategies.at(0);
@ -801,7 +777,7 @@ Status SquareSumAllInfo::InferGroup() {
// if repeated calculation and the repeated_calc_num_ insert to the first dimension of dev matrix,
// it need to handle the first dimension of map.
if ((dev_matrix_shape_.size() > size) && !repeated_num_in_dev_matrix_right_) {
group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1)));
group_creat_map.push_back(SizeToLong(dev_matrix_shape_.size() - size_t(1)));
}
for (size_t index = 0; index < size; ++index) {

View File

@ -30,9 +30,9 @@ namespace mindspore {
namespace parallel {
Status SplitInfo::GetAttrs() {
int64_t axis = 0;
int64_t output_num = 0;
auto axis_iter = attrs_.find(AXIS);
std::string attr_axis_name = GetSplitAxisAttrName();
auto axis_iter = attrs_.find(attr_axis_name);
if (axis_iter != attrs_.end()) {
MS_EXCEPTION_IF_NULL(axis_iter->second);
if (axis_iter->second->isa<Int64Imm>()) {
@ -56,21 +56,6 @@ Status SplitInfo::GetAttrs() {
}
axis_ = LongToSize(axis);
auto output_num_iter = attrs_.find(OUTPUT_NUM);
if (output_num_iter != attrs_.end()) {
MS_EXCEPTION_IF_NULL(output_num_iter->second);
if (output_num_iter->second->isa<Int64Imm>()) {
output_num = output_num_iter->second->cast<Int64ImmPtr>()->value();
} else {
MS_LOG(ERROR) << name_ << ": The value of output_num is not int";
return FAILED;
}
} else {
MS_LOG(ERROR) << name_ << ": Can not find the output_num attr";
return FAILED;
}
output_num_ = LongToSize(output_num);
return SUCCESS;
}
@ -152,6 +137,9 @@ std::vector<StrategyPtr> SplitInfo::GenerateOpStrategies(int64_t stage_id) {
if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_input, &sp_vector) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Generate strategies failed";
}
if (sp_vector.empty()) {
MS_LOG(EXCEPTION) << name_ << ": No available strategy";
}
return sp_vector;
}

View File

@ -46,10 +46,21 @@ class SplitInfo : public OperatorInfo {
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status InferAsLossDivisor() override;
virtual std::string GetSplitAxisAttrName() const { return AXIS; }
private:
size_t axis_ = 0;
size_t output_num_ = 0;
};
class SplitVInfo : public SplitInfo {
public:
SplitVInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: SplitInfo(operator_name, inputs_shape, outputs_shape, attrs) {}
~SplitVInfo() override = default;
protected:
std::string GetSplitAxisAttrName() const override { return SPLIT_DIM; };
};
} // namespace parallel
} // namespace mindspore

View File

@ -28,6 +28,19 @@
namespace mindspore {
namespace parallel {
Status UniformRealInfo::InferAttrs() {
if (infer_attrs_completed_) {
return SUCCESS;
}
if (GetAttrs() != SUCCESS) {
return FAILED;
}
ResetInputsShape();
infer_attrs_completed_ = true;
return SUCCESS;
}
Status UniformRealInfo::GetAttrs() {
seed_ = GetIntAttr(SEED);
if (seed_ < 0) {
@ -39,10 +52,6 @@ Status UniformRealInfo::GetAttrs() {
MS_LOG(ERROR) << name_ << ": Seed2 must be greater or equal to zero, bug got " << seed2_;
return FAILED;
}
ValueTuplePtr shape_value = input_value_[0]->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(shape_value);
inputs_shape_.push_back(GetValue<Shape>(shape_value));
return SUCCESS;
}
@ -97,7 +106,10 @@ std::vector<StrategyPtr> UniformRealInfo::GenerateOpStrategies(int64_t stage_id)
std::vector<StrategyPtr> sp_vector;
if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << " : Generate strategies for independent inputs() failed.";
MS_LOG(EXCEPTION) << name_ << ": Generate strategies for independent inputs() failed.";
}
if (sp_vector.empty()) {
MS_LOG(EXCEPTION) << name_ << ": No available strategy.";
}
return sp_vector;
}
@ -120,23 +132,29 @@ void UniformRealInfo::UpdateShape(const CNodePtr &cnode) {
}
void UniformRealInfo::ReplaceNodeInputOrAttrs() {
// Replace input 'shape' to slice shape
auto cnode = cnode_;
int64_t split_num = 1;
UpdateShape(cnode);
std::vector<Dimensions> stra = strategy_->GetInputDim();
for (size_t i = 0; i < stra[0].size(); i++) {
split_num *= stra[0][i];
}
int64_t device_num = stage_device_size_;
int64_t rank_id = g_device_manager->rank_index_in_stage();
if (device_num != split_num) {
auto split_group_num = device_num / split_num;
int64_t seed_bias = rank_id / split_group_num;
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
prim->set_attr(SEED, MakeValue(seed_ + seed_bias));
prim->set_attr(SEED2, MakeValue(seed2_ + seed_bias));
// Update seed according rank_id
int64_t rank_id = g_device_manager->rank_index_in_stage();
int64_t seed_bias;
if (repeated_num_in_dev_matrix_right_) {
seed_bias = rank_id / repeated_calc_num_;
} else {
int64_t device_num = stage_device_size_;
seed_bias = rank_id % (device_num / repeated_calc_num_);
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
prim->set_attr(SEED, MakeValue(seed_ + seed_bias));
prim->set_attr(SEED2, MakeValue(seed2_ + seed_bias));
}
void UniformRealInfo::ResetInputsShape() {
ValueTuplePtr shape_value = input_value_[0]->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(shape_value);
inputs_shape_.push_back(GetValue<Shape>(shape_value));
}
} // namespace parallel
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNIFORMREAL_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNIFORMREAL_INFO_H_
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNIFORM_REAL_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNIFORM_REAL_INFO_H_
#include <string>
#include <memory>
@ -33,7 +33,7 @@ class UniformRealInfo : public OperatorInfo {
public:
UniformRealInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<MaxPoolCost>()) {}
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<UniformRealCost>()) {}
~UniformRealInfo() override = default;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
@ -42,6 +42,7 @@ class UniformRealInfo : public OperatorInfo {
void ReplaceNodeInputOrAttrs() override;
protected:
Status InferAttrs() override;
Status GetAttrs() override;
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferForwardCommunication() override { return SUCCESS; }
@ -49,11 +50,12 @@ class UniformRealInfo : public OperatorInfo {
Status InferTensorMap() override;
private:
void ResetInputsShape();
int64_t seed_ = 0;
int64_t seed2_ = 0;
};
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNIFORMREAL_INFO_H_
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNIFORM_REAL_INFO_H_

View File

@ -71,8 +71,9 @@ Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) {
max_size_strategy_dim_ = i;
}
}
if (std::find(stra[max_size_strategy_dim_].begin(), stra[max_size_strategy_dim_].end(), shard_num_) ==
stra[max_size_strategy_dim_].end()) {
if (!stra[max_size_strategy_dim_].empty() &&
std::find(stra[max_size_strategy_dim_].begin(), stra[max_size_strategy_dim_].end(), shard_num_) ==
stra[max_size_strategy_dim_].end()) {
MS_LOG(ERROR) << name_
<< ": For each dataset input, the shard strategy can be not shard, "
"or shard in one dim with the same shard size between each input."
@ -96,7 +97,7 @@ Status VirtualDatasetInfo::InferForwardCommunication() { return SUCCESS; }
Status VirtualDatasetInfo::InferTensorMap() {
auto dev_mat_origin = strategy_->GetInputDim()[max_size_strategy_dim_];
auto slice_dim_iter = std::find(dev_mat_origin.begin(), dev_mat_origin.end(), shard_num_);
if (slice_dim_iter == dev_mat_origin.end()) {
if (!dev_mat_origin.empty() && slice_dim_iter == dev_mat_origin.end()) {
MS_LOG(ERROR) << name_ << ": The dataset shard strategy only support shard in one dim.";
return FAILED;
}
@ -108,8 +109,7 @@ Status VirtualDatasetInfo::InferTensorMap() {
if (dim == 1) {
tensor_map_index.push_back(MAP_NONE);
} else if (dim == shard_num_) {
if (repeated_num_in_dev_matrix_right_ && dev_matrix_shape_.size() != dev_mat_origin.size() &&
is_auto_parallel_) {
if (repeated_calc_num_ > 1 && repeated_num_in_dev_matrix_right_ && is_auto_parallel_) {
tensor_map_index.push_back(dev_mat_origin.size() - slice_dim);
} else {
tensor_map_index.push_back(dev_mat_origin.size() - 1 - slice_dim);
@ -176,8 +176,10 @@ Status VirtualDatasetInfo::GenerateStrategies(int64_t stage_id) {
}
for (auto &shape : inputs_shape_) {
Shape temp;
temp.emplace_back(SizeToLong(total_dev_num));
(void)temp.insert(temp.end(), shape.size() - 1, 1);
if (!shape.empty()) {
temp.emplace_back(SizeToLong(total_dev_num));
(void)temp.insert(temp.end(), shape.size() - 1, 1);
}
strategy.push_back(temp);
}
}

View File

@ -179,7 +179,7 @@ bool IsSplittableOperator(const std::string &op_name) {
RESIZE_NEAREST_NEIGHBOR, CUM_SUM, FAST_GELU, IOU, BOUNDING_BOX_ENCODE, RANDOM_CHOICE_WITH_MASK, CROP_AND_RESIZE,
ROI_ALIGN, IS_FINITE, RINT, HSHRINK, HSIGMOID, MISH, SELU, SOFT_SHRINK, XLOGY, XDIVY, CUM_PROD, BITWISE_AND,
BITWISE_OR, BITWISE_XOR, MUL_NO_NAN, TRUNCATE_DIV, TRUNCATE_MOD, INPLACE_ADD, INPLACE_SUB, L2_LOSS, LERP, ADDN,
CDIST};
CDIST, SQUARED_DIFFERENCE, ERFINV, MASKED_FILL, SPLITV, GAMMA, KLDIV_LOSS, LIN_SPACE};
// clang-format on
auto iter = splittable_op.find(op_name);

View File

@ -1809,18 +1809,19 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
std::vector<ValuePtr> elements;
for (size_t i = 0; i < shape_list[0].size(); i++) {
if (shape_list[0][i].empty()) {
MS_LOG(EXCEPTION) << "shape_list[ " << i << " ].size() is zero";
(void)elements.emplace_back(MakeValue(Dimensions()));
continue;
}
Dimensions input_strategy;
if (!shape_list[0][i].empty() && shape_list[0][i][0] % dev_num == 0) {
if (shape_list[0][i][0] % dev_num == 0) {
input_strategy.push_back(dev_num);
} else if (!shape_list[0][i].empty()) {
} else {
input_strategy.push_back(1);
}
for (size_t j = 1; j < shape_list[0][i].size(); j++) {
input_strategy.push_back(1);
}
elements.push_back(MakeValue(input_strategy));
(void)elements.emplace_back(MakeValue(input_strategy));
}
ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
attrs_temp[IN_STRATEGY] = strategy;

View File

@ -1110,3 +1110,89 @@ def test_matmul_xlogy_broadcast():
y = Tensor(np.ones([32, 1]), dtype=ms.float32)
b = Tensor(np.ones([1, 64]), dtype=ms.float32)
compile_net(net, x, y, b)
def test_matmul_squared_difference_broadcast():
"""
Feature: distribute operator SquaredDifference in auto parallel.
Description: mul-SquaredDifference net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().shard(strategy1)
self.squared_difference = P.SquaredDifference().shard(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
out = self.squared_difference(out, b)
return out
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((2, 4), (4, 1))
strategy2 = ((4, 1), (1, 2))
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 1]), dtype=ms.float32)
b = Tensor(np.ones([1, 64]), dtype=ms.float32)
compile_net(net, x, y, b)
def test_matmul_masked_fill_broadcast_with_value_float():
"""
Feature: distribute operator MaskedFill in auto parallel.
Description: mul-MaskedFill net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().shard(strategy1)
self.masked_fill = P.MaskedFill().shard(strategy2)
self.value = 1.0
def construct(self, x, y, b):
out = self.matmul(x, y)
out = self.masked_fill(out, b, self.value)
return out
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((2, 4), (4, 1))
strategy2 = ((4, 1), (1, 2))
net = Net(strategy1, strategy2)
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 1]), dtype=ms.float32)
b = Tensor(np.ones([1, 64]), dtype=ms.bool_)
compile_net(net, x, y, b)
def test_matmul_masked_fill_broadcast_with_value_tensor():
"""
Feature: distribute operator MaskedFill in auto parallel.
Description: mul-MaskedFill net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().shard(strategy1)
self.masked_fill = P.MaskedFill().shard(strategy2)
self.value = Tensor(1.0, ms.float32)
def construct(self, x, y, b):
out = self.matmul(x, y)
out = self.masked_fill(out, b, self.value)
return out
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((2, 4), (4, 1))
strategy2 = ((4, 1), (1, 2), ())
net = Net(strategy1, strategy2)
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 1]), dtype=ms.float32)
b = Tensor(np.ones([1, 64]), dtype=ms.bool_)
compile_net(net, x, y, b)

View File

@ -1173,3 +1173,34 @@ def test_matmul_soft_shrink():
y = Tensor(np.random.uniform(-5, 5, size=(32, 64)), dtype=ms.float32)
b = Tensor(np.random.uniform(-5, 5, size=(64, 64)), dtype=ms.float32)
compile_net(net, x, y, b)
def test_matmul_erfinv():
"""
Feature: distribute operator Erfinv in auto parallel.
Description: matmul-squared_difference-matmul net with strategy in semi auto parallel.
Expectation: compile done without error.
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().shard(strategy1)
self.erfinv = P.Erfinv().shard(strategy2)
self.matmul2 = P.MatMul().shard(strategy1)
def construct(self, x, y, b):
out = self.matmul(x, y)
out = self.erfinv(out)
out = self.matmul2(out, b)
return out
context.set_auto_parallel_context(device_num=8, global_rank=0)
strategy1 = ((2, 2), (2, 2))
strategy2 = ((4, 2),)
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
x = Tensor(np.random.uniform(-5, 5, size=(128, 32)), dtype=ms.float32)
y = Tensor(np.random.uniform(-5, 5, size=(32, 64)), dtype=ms.float32)
b = Tensor(np.random.uniform(-5, 5, size=(64, 64)), dtype=ms.float32)
compile_net(net, x, y, b)

View File

@ -0,0 +1,94 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
import mindspore.common.dtype as mstype
from mindspore import Tensor, context
from mindspore.nn import Cell
from mindspore.ops import operations as P
from parallel.utils.utils import ParallelValidator, compile_net
SEED_ = 1
SEED2_ = 1
alpha_ = Tensor(np.array([1.0]), mstype.float32)
beta_ = Tensor(np.array([1.0]), mstype.float32)
class Net(Cell):
def __init__(self, seed, seed2, strategy=None):
super(Net, self).__init__()
self.gamma = P.Gamma(seed, seed2).shard(strategy)
def construct(self, shape, alpha, beta):
out = self.gamma(shape, alpha, beta)
return out
def test_gamma_auto_parallel():
"""
Features: test Gamma auto parallel
Description: auto parallel
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0, full_batch=True)
net = Net(SEED_, SEED2_)
shape = (4, 4, 4)
compile_net(net, shape, alpha_, beta_)
def test_gamma_data_parallel():
"""
Features: test Gamma data parallel
Description: data parallel
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=1)
net = Net(SEED_, SEED2_)
shape = (8, 8)
phase = compile_net(net, shape, alpha_, beta_)
validator = ParallelValidator(net, phase)
assert validator.check_node_attrs("Gamma-0", {"seed": 2, "seed2": 2})
def test_gamma_model_parallel():
"""
Features: test Gamma model parallel
Description: model parallel
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=5)
shape = (8, 8)
strategy = ((2, 2), (1,), (1,))
net = Net(SEED_, SEED2_, strategy)
phase = compile_net(net, shape, alpha_, beta_)
validator = ParallelValidator(net, phase)
assert validator.check_node_attrs("Gamma-0", {"seed": 3, "seed2": 3})
def test_gamma_strategy_error():
"""
Features:test Gamma strategy error
Description: invalid strategy
Expectation: Raise RuntimeError
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
shape = (8, 8)
strategy = ((2, 2), (2,), (1,))
net = Net(SEED_, SEED2_, strategy)
with pytest.raises(RuntimeError):
compile_net(net, shape, alpha_, beta_)

View File

@ -0,0 +1,127 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, context
from mindspore.nn import Cell
from mindspore.ops import operations as P
from parallel.utils.utils import ParallelValidator, compile_net
logits_ = Tensor(np.random.uniform(0, 1, [8, 8]), mstype.float32)
labels_ = Tensor(np.random.randint(0, 10, [8, 8]), mstype.float32)
class Net(Cell):
def __init__(self, reduction, strategy=None):
super(Net, self).__init__()
self.kldiv_loss = P.KLDivLoss(reduction).shard(strategy)
def construct(self, logits, labels):
out = self.kldiv_loss(logits, labels)
return out
def test_kldiv_loss_mean_auto_parallel():
"""
Features: test KLDivLoss auto parallel
Description: auto parallel, reduction is 'mean'
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0, full_batch=True)
reduction = 'mean'
net = Net(reduction)
compile_net(net, logits_, labels_)
def test_kldiv_loss_none_auto_parallel():
"""
Features: test KLDivLoss auto parallel
Description: auto parallel, reduction is 'none'
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0, full_batch=True)
reduction = 'none'
net = Net(reduction)
compile_net(net, logits_, labels_)
def test_kldiv_loss_sum_auto_parallel():
"""
Features: test KLDivLoss auto parallel
Description: auto parallel, reduction is 'sum'
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0, full_batch=True)
reduction = 'sum'
net = Net(reduction)
compile_net(net, logits_, labels_)
def test_kldiv_loss_mean_data_parallel():
"""
Features: test KLDivLoss data parallel
Description: data parallel, reduction is 'mean'
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=1)
reduction = 'mean'
net = Net(reduction)
phase = compile_net(net, logits_, labels_)
validator = ParallelValidator(net, phase)
assert validator.check_node_inputs('AllReduce-0', ['KLDivLoss-0'])
assert validator.check_node_attrs('AllReduce-0', {'op': 'sum'})
def test_kldiv_loss_none_data_parallel():
"""
Features: test KLDivLoss data parallel
Description: data parallel, reduction is 'none'
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=1)
reduction = 'none'
net = Net(reduction)
compile_net(net, logits_, labels_)
def test_kldiv_loss_none_model_parallel():
"""
Features: test KLDivLoss model parallel
Description: model parallel, reduction is 'none'
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=5)
reduction = 'none'
strategy = ((2, 2), (2, 2))
net = Net(reduction, strategy)
compile_net(net, logits_, labels_)
def test_kldiv_loss_mean_model_parallel():
"""
Features: test KLDivLoss model parallel
Description: model parallel, reduction is 'mean'
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=5)
reduction = 'mean'
strategy = ((4, 2), (4, 2))
net = Net(reduction, strategy)
phase = compile_net(net, logits_, labels_)
validator = ParallelValidator(net, phase)
assert validator.check_node_inputs('AllReduce-0', ['KLDivLoss-0'])
assert validator.check_node_attrs('AllReduce-0', {'op': 'sum'})

View File

@ -18,7 +18,7 @@ from mindspore import Tensor, context
from mindspore.nn import Cell
from mindspore.ops import operations as P
from parallel.utils.utils import compile_net
from parallel.utils.utils import ParallelValidator, compile_net
x_ = Tensor(np.random.normal(size=[32, 8, 8]).astype(np.float32))
@ -52,4 +52,21 @@ def test_l2_loss_model_parallel():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy = ((2, 2, 2),)
net = Net(strategy)
compile_net(net, x_)
phase = compile_net(net, x_)
validator = ParallelValidator(net, phase)
assert validator.check_node_inputs('AllReduce-0', ['L2Loss-0'])
assert validator.check_node_attrs('AllReduce-0', {'op': 'sum'})
def test_l2_loss_data_parallel():
"""
Feature: test L2Loss data parallel
Description: data parallel
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
net = Net()
phase = compile_net(net, x_)
validator = ParallelValidator(net, phase)
assert validator.check_node_inputs('AllReduce-0', ['L2Loss-0'])
assert validator.check_node_attrs('AllReduce-0', {'op': 'sum'})

View File

@ -0,0 +1,111 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import mindspore.common.dtype as mstype
from mindspore import Tensor, context
from mindspore.nn import Cell
from mindspore.ops import operations as P
from parallel.utils.utils import ParallelValidator, compile_net
class Net(Cell):
def __init__(self, strategy=None):
super(Net, self).__init__()
self.lin_space = P.LinSpace().shard(strategy)
def construct(self, start, end, x):
return self.lin_space(start, end, x)
def test_lin_space_loss_auto_parallel():
"""
Feature: test LinSpace auto parallel
Description: auto parallel
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
start = Tensor(1, mstype.float32)
end = Tensor(10, mstype.float32)
x = 8
net = Net()
compile_net(net, start, end, x)
def test_lin_space_parallel_with_repeated_cal():
"""
Feature: test LinSpace parallel
Description: parallel with repeated_cal
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
start = Tensor(1, mstype.float32)
end = Tensor(10, mstype.float32)
x = 8
strategy = ((4,),)
net = Net(strategy)
phase = compile_net(net, start, end, x)
validator = ParallelValidator(net, phase)
assert validator.check_node_inputs('LinSpace-0', ['Add-0', 'Add-1', 2])
def test_lin_space_parallel_with_x_2():
"""
Feature: test LinSpace parallel
Description: parallel with input x is 2
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
start = Tensor(1, mstype.float32)
end = Tensor(10, mstype.float32)
x = 2
strategy = ((2,),)
net = Net(strategy)
phase = compile_net(net, start, end, x)
validator = ParallelValidator(net, phase)
assert validator.check_node_inputs('LinSpace-0', ['Add-0', 'Add-1', 1])
def test_lin_space_data_parallel():
"""
Feature: test LinSpace data parallel
Description: data parallel
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
start = Tensor(1, mstype.float32)
end = Tensor(10, mstype.float32)
x = 8
net = Net()
phase = compile_net(net, start, end, x)
validator = ParallelValidator(net, phase)
assert validator.check_node_inputs('LinSpace-0', ['Add-0', 'Add-1', 1])
def test_lin_space_parallel_strategy_wrong():
"""
Feature: test LinSpace parallel with invalid strategy
Description: invalid strategy
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
start = Tensor(1, mstype.float32)
end = Tensor(10, mstype.float32)
x = 6
strategy = ((4,),)
net = Net(strategy)
with pytest.raises(RuntimeError):
compile_net(net, start, end, x)

View File

@ -0,0 +1,118 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.context as context
from mindspore import Tensor, Parameter
import mindspore.nn as nn
from mindspore.ops import operations as P
from parallel.utils.utils import compile_net
input_x_tensor_ = Tensor(np.ones([8, 8, 8]), ms.float32)
input_x_parameter_ = Parameter(Tensor(np.ones([8, 8, 8]), ms.float32), "input_x")
SIZE_SPLIT = [3, 3, 2]
NUM_SPLIT = 3
class Net(nn.Cell):
def __init__(self, size_split, split_dim, num_split, strategy1=None, strategy2=None):
super(Net, self).__init__()
self.splitv = P.SplitV(size_split, split_dim, num_split).shard(strategy1)
self.mul = P.Mul().shard(strategy2)
self.weight = Parameter(np.array([1.0]), name="mul_weight")
def construct(self, x):
out = self.splitv(x)
out = self.mul(out[0], self.weight)
return out
class NetWithParameter(nn.Cell):
def __init__(self, size_split, split_dim, num_split, strategy1=None, strategy2=None):
super(NetWithParameter, self).__init__()
self.splitv = P.SplitV(size_split, split_dim, num_split).shard(strategy1)
self.mul = P.Mul().shard(strategy2)
self.weight = input_x_parameter_
def construct(self, x):
out = self.splitv(self.weight)
out = self.mul(x, out[0])
return out
def test_splitv_auto_parallel():
"""
Feature: test SplitV auto parallel
Description: auto parallel
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net(SIZE_SPLIT, 0, NUM_SPLIT)
compile_net(net, input_x_tensor_)
def test_splitv_auto_parallel_with_parameter():
"""
Feature: test SplitV auto parallel with parameter input
Description: auto parallel with parameter input
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = NetWithParameter(SIZE_SPLIT, 2, NUM_SPLIT)
x = Tensor(np.ones([8, 8, 3]), ms.float32)
compile_net(net, x)
def test_splitv_data_parallel():
"""
Feature: test SplitV data parallel
Description: data parallel
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
net = Net(SIZE_SPLIT, 1, NUM_SPLIT)
compile_net(net, input_x_tensor_)
def test_splitv_model_parallel():
"""
Feature: test SplitV model parallel
Description: model parallel
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 2, 2),)
strategy2 = ((1, 2, 2), (1,))
net = Net(SIZE_SPLIT, 0, NUM_SPLIT, strategy1, strategy2)
compile_net(net, input_x_tensor_)
def test_splitv_strategy_error():
"""
Feature: test SplitV parallel with invalid strategy
Description: config invalid strategy
Expectation: raise RuntimeError
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2, 2),)
strategy2 = ((8, 1, 1),)
net = Net(SIZE_SPLIT, 0, NUM_SPLIT, strategy1, strategy2)
with pytest.raises(RuntimeError):
compile_net(net, input_x_tensor_)
context.reset_auto_parallel_context()

View File

@ -0,0 +1,74 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mindspore import context
from mindspore.nn import Cell
from mindspore.ops import operations as P
from parallel.utils.utils import ParallelValidator, compile_net
SEED_ = 1
SEED2_ = 1
class Net(Cell):
def __init__(self, seed, seed2, strategy=None):
super(Net, self).__init__()
self.uniform_real = P.UniformReal(seed, seed2).shard(strategy)
def construct(self, shape):
out = self.uniform_real(shape)
return out
def test_uniform_real_auto_parallel():
"""
Features: test UniformReal auto parallel
Description: auto parallel
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net(SEED_, SEED2_)
shape = (4, 4, 4)
compile_net(net, shape)
def test_uniform_real_data_parallel():
"""
Features: test UniformReal data parallel
Description: data parallel
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=1)
net = Net(SEED_, SEED2_)
shape = (8, 8)
phase = compile_net(net, shape)
validator = ParallelValidator(net, phase)
assert validator.check_node_attrs("UniformReal-0", {"seed": 2, "seed2": 2})
def test_uniform_real_model_parallel():
"""
Features: test UniformReal model parallel
Description: model parallel
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=5)
shape = (8, 8)
strategy = ((2, 2),)
net = Net(SEED_, SEED2_, strategy)
phase = compile_net(net, shape)
validator = ParallelValidator(net, phase)
assert validator.check_node_attrs("UniformReal-0", {"seed": 3, "seed2": 3})