commit e897c98b1f
Merge pull request !31988 from liuluobin/ops_impl

!31988 Implementation of SquaredDifferenceInfo, ErfinvInfo, MaskedFillInfo, SplitVInfo, GammaInfo, KLDivLossInfo and LinSpaceInfo.

@@ -192,6 +192,8 @@ using IOUCost = CastCost;
 using RandomChoicWithMaskCost = CastCost;
 using IsFiniteCost = CastCost;
 using RintCost = CastCost;
+using GammaCost = CastCost;
+using LinSpaceCost = CastCost;

 class SqrtCost : public CastCost {
  public:

@@ -247,6 +249,7 @@ using ErfcCost = ReLU6Cost;
 using ActivationInfoCost = ReLU6Cost;
 using SelectCost = ReLU6Cost;
 using XlogyCost = ReLU6Cost;
+using ErfinvCost = ReLU6Cost;

 class TransposeCost : public CastCost {
  public:

@@ -634,6 +637,7 @@ using BitwiseOrCost = SubCost;
 using BitwiseXorCost = SubCost;
 using AddNCost = SubCost;
 using InplaceAddCost = SubCost;
+using MaskedFillCost = SubCost;

 class MulCost : public SubCost {
  public:

@@ -646,6 +650,7 @@ class MulCost : public SubCost {
 using MulNoNanCost = MulCost;
 using GatherDCost = MulCost;
 using LerpCost = MulCost;
+using SquaredDifferenceCost = MulCost;

 class DivCost : public SubCost {
  public:

@@ -777,6 +782,7 @@ using ReduceMethodCost = ReduceSumCost;
 using ReduceProdCost = ReduceSumCost;
 using SquareSumAllCost = ReduceSumCost;
 using L2LossCost = ReduceSumCost;
+using KLDivLossCost = ReduceSumCost;

 class ReduceMeanCost : public ReduceSumCost {
  public:

@@ -243,6 +243,13 @@ REGISTER(InplaceSubInfo);
 REGISTER(CdistInfo);
 REGISTER(L2LossInfo);
 REGISTER(LerpInfo);
+REGISTER(SquaredDifferenceInfo);
+REGISTER(ErfinvInfo);
+REGISTER(MaskedFillInfo);
+REGISTER(SplitVInfo);
+REGISTER(GammaInfo);
+REGISTER(KLDivLossInfo);
+REGISTER(LinSpaceInfo);
 }  // namespace parallel
 }  // namespace mindspore

@@ -91,6 +91,19 @@ AnfNodePtr CreateInt32Tensor(int64_t value) {
   return anf_node_ptr;
 }

+static mindspore::HashMap<float, AnfNodePtr> float_tensor_map = {};
+AnfNodePtr CreateFP32Tensor(float value) {
+  auto it = float_tensor_map.find(value);
+  if (it != float_tensor_map.end()) {
+    return it->second;
+  }
+  mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<tensor::Tensor>(value, kFloat32);
+  ValuePtr value_ptr = MakeValue(tensor_ptr);
+  auto anf_node_ptr = ValuePtrToAnfNodePtr(value_ptr);
+  float_tensor_map[value] = anf_node_ptr;
+  return anf_node_ptr;
+}
+
 AnfNodePtr CreateTypeInt(int64_t nbits) {
   ValuePtr value_ptr = MakeValue(std::make_shared<Int>(nbits));
   return ValuePtrToAnfNodePtr(value_ptr);

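A note on the cache in CreateFP32Tensor above: a hash map keyed on float relies on exact bit equality, so only the identical bit pattern hits the cache. A self-contained illustration using plain std::unordered_map (an assumption: mindspore::HashMap behaves like the standard container for this purpose):

#include <cmath>
#include <iostream>
#include <unordered_map>

int main() {
  std::unordered_map<float, int> cache;
  cache[0.1f] = 1;
  std::cout << cache.count(0.1f) << "\n";  // 1: the same bit pattern hits the cache
  // NaN compares unequal to itself, so a NaN key can never be found again;
  // repeated NaN lookups would keep inserting fresh entries.
  float nan = std::nanf("");
  cache[nan] = 2;
  std::cout << cache.count(nan) << "\n";  // 0
  return 0;
}
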
@@ -42,6 +42,7 @@ AnfNodePtr CreateTypeFloat(int64_t nbits);
 AnfNodePtr CreatInt64Imm(int64_t value);
 AnfNodePtr CreateFP32Imm(float value);
 AnfNodePtr CreateInt32Tensor(int64_t value);
+AnfNodePtr CreateFP32Tensor(float value);
 AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr);
 AnfNodePtr CreateTuple(const std::vector<int64_t> &tuple);
 std::string HashInstanceName(const std::string &name);

@@ -76,7 +76,16 @@ std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
     } else {
       inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
     }
-    return std::vector<bool>(inputs_seq.size(), false);
+    size_t inputs_seq_tensor_size = inputs_seq.size();
+    for (const auto &inputs_seq_value : inputs_seq) {
+      auto tensor = inputs_seq_value->cast<tensor::TensorPtr>();
+      if (tensor == nullptr) {
+        MS_LOG(DEBUG) << "The value is not a tensor.";
+        inputs_seq_tensor_size = 0;
+        break;
+      }
+    }
+    return std::vector<bool>(inputs_seq_tensor_size, false);
   }
   if ((node_inputs.size() == 2) &&
       (AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) {

@@ -168,7 +177,10 @@ std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
     }
     for (auto &ele : inputs_seq) {
       auto tensor = ele->cast<tensor::TensorPtr>();
-      MS_EXCEPTION_IF_NULL(tensor);
+      if (tensor == nullptr) {
+        inputs_type_len.clear();
+        return inputs_type_len;
+      }
       inputs_type_len.push_back(GetLengthOfDataType(tensor->Dtype()));
     }
     return inputs_type_len;

@@ -28,6 +28,7 @@
 #include "frontend/parallel/device_matrix.h"
 #include "frontend/parallel/strategy.h"
 #include "frontend/parallel/tensor_layout/redistribution_operator_infer.h"
+#include "frontend/parallel/graph_util/generate_graph.h"

 namespace mindspore {
 namespace parallel {

@@ -616,5 +617,33 @@ Status L2LossInfo::InferTensorMap() {
   outputs_tensor_map_[0].clear();
   return SUCCESS;
 }
+
+Status L2LossInfo::InferForwardCommunication() {
+  forward_op_.clear();
+  Shape group_create_map;
+  if (repeated_calc_num_ > 1) {
+    if (repeated_num_in_dev_matrix_right_) {
+      group_create_map.push_back(0);
+    } else {
+      group_create_map.push_back(SizeToLong(dev_matrix_shape_.size()) - 1);
+    }
+  }
+
+  std::vector<Group> group_list;
+  if (CreateGroupByTensorMap(group_create_map, &group_list) != SUCCESS) {
+    ReportError(name_ + ": Create group failed.");
+    return FAILED;
+  }
+  if (group_list.empty()) {
+    MS_LOG(INFO) << name_ << ": Forward all reduce is not required";
+    return SUCCESS;
+  }
+
+  Operator op = CreateAllReduceOp(REDUCE_OP_SUM, group_list[0].name());
+  forward_op_.push_back(op);
+  MS_LOG(INFO) << name_ << ": The group name of forward all reduce is " << group_list[0].name();
+
+  return SUCCESS;
+}
 }  // namespace parallel
 }  // namespace mindspore

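Why an AllReduce-SUM is the right forward op for L2LossInfo: L2Loss(x) = sum(x^2) / 2 is a scalar, so with the input sharded each device only holds a partial sum. A worked example with illustrative numbers (not from the patch):

// x = {1, 2, 3, 4} split over 2 devices
// device 0: x0 = {1, 2} -> partial = (1 + 4) / 2  = 2.5
// device 1: x1 = {3, 4} -> partial = (9 + 16) / 2 = 12.5
// AllReduce(SUM): 2.5 + 12.5 = 15.0 == L2Loss over the full tensor
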
@@ -374,13 +374,22 @@ class SoftShrinkInfo : public ActivationOther {

 class L2LossInfo : public ActivationOther {
  public:
-  L2LossInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &output_shape,
+  L2LossInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
              const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, output_shape, attrs, std::make_shared<L2LossCost>()) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<L2LossCost>()) {}
   ~L2LossInfo() = default;

  protected:
   Status InferTensorMap() override;
+  Status InferForwardCommunication() override;
+};
+
+class ErfinvInfo : public ActivationOther {
+ public:
+  ErfinvInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+             const PrimitiveAttrs &attrs)
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ErfinvCost>()) {}
+  ~ErfinvInfo() override = default;
 };
 }  // namespace parallel
 }  // namespace mindspore

@@ -353,5 +353,39 @@ void LerpInfo::ReComputeBatchSplitFlagList() {
   Shape expand_weight_shape = ExpandShape(expand_a_shape, inputs_shape_.at(2));
   (expand_weight_shape.at(0) != 1) ? (split_flag_list_[2] = true) : (split_flag_list_[2] = false);
 }
+
+Status MaskedFillInfo::GetAttrs() {
+  input_size_ = inputs_shape_.size();
+  if (input_size_ != 2 && input_size_ != 3) {
+    MS_LOG(ERROR) << name_ << ": inputs_shape_.size() must be 2 or 3, but got size " << input_size_;
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+Status MaskedFillInfo::InferTensorMap() {
+  if (ArithmeticBase::InferTensorMap() != SUCCESS) {
+    return FAILED;
+  }
+
+  if (input_size_ == 3) {
+    // append a void tensor map for 0-dimensional tensor input 'value'
+    (void)inputs_tensor_map_.emplace_back(TensorMap());
+  }
+  return SUCCESS;
+}
+
+std::vector<StrategyPtr> MaskedFillInfo::GenerateOpStrategies(int64_t stage_id) {
+  auto sp_vector = ArithmeticBase::GenerateOpStrategies(stage_id);
+  if (input_size_ == 3) {
+    // append void strategy for input `value`
+    for (size_t i = 0; i < sp_vector.size(); ++i) {
+      auto strategies = sp_vector[i]->GetInputDim();
+      (void)strategies.emplace_back(Dimensions());
+      sp_vector[i] = std::make_shared<Strategy>(stage_id, strategies);
+    }
+  }
+  return sp_vector;
+}
 }  // namespace parallel
 }  // namespace mindspore

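The empty TensorMap() and Dimensions() appended above encode the same fact in two places: MaskedFill's third input `value` is a 0-dimensional tensor, so it has no axes to shard, and both its tensor map and its strategy entry are empty. In the tuple-of-tuples strategy notation that looks like the following (hypothetical shard values for a 2-D input):

// input:  (2, 4)  -> shard axis 0 over 2 devices, axis 1 over 4
// mask:   (2, 4)  -> follows the input layout
// value:  ()      -> scalar, replicated on every device, nothing to split
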
@@ -263,6 +263,31 @@ class LerpInfo : public ArithmeticBase {
  private:
   size_t inputs_size_ = 0;
 };

+class SquaredDifferenceInfo : public ArithmeticBase {
+ public:
+  SquaredDifferenceInfo(const std::string &name, const Shapes &input_shape, const Shapes &output_shape,
+                        const PrimitiveAttrs &attrs)
+      : ArithmeticBase(name, input_shape, output_shape, attrs, std::make_shared<SquaredDifferenceCost>()) {}
+  ~SquaredDifferenceInfo() override = default;
+};
+
+class MaskedFillInfo : public ArithmeticBase {
+ public:
+  MaskedFillInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &output_shape,
+                 const PrimitiveAttrs &attrs)
+      : ArithmeticBase(name, inputs_shape, output_shape, attrs, std::make_shared<MaskedFillCost>()) {}
+  ~MaskedFillInfo() override = default;
+
+  std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
+
+ protected:
+  Status GetAttrs() override;
+  Status InferTensorMap() override;
+
+ private:
+  size_t input_size_ = 0;
+};
 }  // namespace parallel
 }  // namespace mindspore

@@ -0,0 +1,218 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <map>
+#include <utility>
+
+#include "frontend/parallel/ops_info/gamma_info.h"
+#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
+#include "frontend/parallel/auto_parallel/edge_costmodel.h"
+
+namespace mindspore {
+namespace parallel {
+Status GammaInfo::InferAttrs() {
+  if (infer_attrs_completed_) {
+    return SUCCESS;
+  }
+
+  if (GetAttrs() != SUCCESS) {
+    return FAILED;
+  }
+  ResetInputsShape();
+  infer_attrs_completed_ = true;
+  return SUCCESS;
+}
+
+Status GammaInfo::GetAttrs() {
+  seed_ = GetIntAttr(SEED);
+  if (seed_ < 0) {
+    MS_LOG(ERROR) << name_ << ": Seed must be greater or equal to zero, but got " << seed_;
+    return FAILED;
+  }
+  seed2_ = GetIntAttr(SEED2);
+  if (seed2_ < 0) {
+    MS_LOG(ERROR) << name_ << ": Seed2 must be greater or equal to zero, but got " << seed2_;
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+Status GammaInfo::CheckStrategy(const StrategyPtr &strategy) {
+  MS_EXCEPTION_IF_NULL(strategy);
+  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Invalid strategy";
+    return FAILED;
+  }
+  auto strategies = strategy->GetInputDim();
+  auto alpha_strategy = strategies.at(1);
+  auto beta_strategy = strategies.at(2);
+  if (!IsNotSplittableStrategy(alpha_strategy) || !IsNotSplittableStrategy(beta_strategy)) {
+    MS_LOG(ERROR) << name_ << ": Cannot shard the input `alpha` and `beta`, but got strategy "
+                  << StrategyToString(strategies);
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+Status GammaInfo::InferDevMatrixShape() {
+  dev_matrix_shape_.clear();
+  auto strategies = strategy_->GetInputDim();
+  auto shape_strategy = strategies.at(0);
+  dev_matrix_shape_ = shape_strategy;
+
+  MS_LOG(INFO) << name_ << ": The dev matrix is " << ShapeToString(dev_matrix_shape_);
+  return SUCCESS;
+}
+
+Status GammaInfo::InferTensorMap() {
+  auto strategies = strategy_->GetInputDim();
+  auto shape_strategy = strategies.at(0);
+  size_t size = shape_strategy.size();
+  TensorMap shape_tensor_map;
+  for (size_t i = 0; i < size; ++i) {
+    shape_tensor_map.push_back(size - i - 1);
+  }
+
+  inputs_tensor_map_.push_back(shape_tensor_map);
+  (void)inputs_tensor_map_.emplace_back(TensorMap(strategies.at(1).size(), -1));
+  (void)inputs_tensor_map_.emplace_back(TensorMap(strategies.at(2).size(), -1));
+  (void)outputs_tensor_map_.emplace_back(shape_tensor_map);
+  return SUCCESS;
+}
+
+std::vector<StrategyPtr> GammaInfo::GenerateOpStrategies(int64_t stage_id) {
+  Shape input0_split(inputs_shape_[0].size(), 1);
+  Shape input1_split(inputs_shape_[1].size(), 0);
+  Shape input2_split(inputs_shape_[2].size(), 0);
+  Shapes splittable_inputs = {input0_split, input1_split, input2_split};
+
+  std::vector<StrategyPtr> sp_vector;
+  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
+    MS_LOG(EXCEPTION) << name_ << ": Generate strategies for independent inputs() failed.";
+  }
+  if (sp_vector.empty()) {
+    MS_LOG(EXCEPTION) << name_ << ": No available strategy.";
+  }
+  return sp_vector;
+}
+
+void GammaInfo::UpdateShape(const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto input_node = cnode->input(1)->cast<ValueNodePtr>();
+  std::vector<int64_t> input_shape = GetValue<std::vector<int64_t>>(input_node->value());
+  std::vector<Dimensions> stra = strategy_->GetInputDim();
+  for (size_t i = 0; i < stra[0].size(); i++) {
+    input_shape[i] /= stra[0][i];
+  }
+  auto func_graph = cnode->func_graph();
+  MS_EXCEPTION_IF_NULL(func_graph);
+  auto manager = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  ValuePtr new_shape = MakeValue(input_shape);
+  AnfNodePtr val = NewValueNode(new_shape);
+  (void)manager->Replace(cnode->input(1), val);
+}
+
+void GammaInfo::ReplaceNodeInputOrAttrs() {
+  // Replace input 'shape' to slice shape
+  auto cnode = cnode_;
+  UpdateShape(cnode);
+
+  // Update seed according to rank_id
+  int64_t rank_id = g_device_manager->rank_index_in_stage();
+  int64_t seed_bias;
+  if (repeated_num_in_dev_matrix_right_) {
+    seed_bias = rank_id / repeated_calc_num_;
+  } else {
+    int64_t device_num = stage_device_size_;
+    seed_bias = rank_id % (device_num / repeated_calc_num_);
+  }
+
+  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+  prim->set_attr(SEED, MakeValue(seed_ + seed_bias));
+  prim->set_attr(SEED2, MakeValue(seed2_ + seed_bias));
+}
+
+void GammaInfo::ResetInputsShape() {
+  if (inputs_shape_.size() == input_value_.size()) {
+    return;
+  }
+
+  ValueTuplePtr shape_value = input_value_[0]->cast<ValueTuplePtr>();
+  MS_EXCEPTION_IF_NULL(shape_value);
+  (void)inputs_shape_.insert(inputs_shape_.begin(), GetValue<Shape>(shape_value));
+}
+
+bool GammaInfo::IsNotSplittableStrategy(const Dimensions &strategy) {
+  return std::all_of(strategy.begin(), strategy.end(), [](int64_t val) { return val == 1; });
+}
+
+void GammaInfo::ReComputeBatchSplitFlagList() {
+  ResetInputsShape();
+  split_flag_list_.clear();
+  split_flag_list_ = std::vector<bool>(inputs_shape_.size(), false);
+  if (!split_flag_list_.empty()) {
+    split_flag_list_[0] = true;
+  }
+}
+
+int64_t GammaInfo::ComputeOpAndPrevEdgeParameterInvolved() {
+  if (is_output_parameter_involve_ != -1) {
+    return is_output_parameter_involve_;
+  }
+  is_parameter_involve_ = is_parameter_;
+  const auto &prev_edges = this->GetAlivePrevEdges();
+  for (auto &p_edge : prev_edges) {
+    auto input_index = p_edge->next_op_input_index();
+    auto prev_op_para = p_edge->prev_operator()->ComputeOpAndPrevEdgeParameterInvolved();
+    if (input_index - 1 >= is_parameter_involve_.size()) {
+      MS_LOG(EXCEPTION) << name_ << " has input length: " << is_parameter_involve_.size()
+                        << ", but got wrong input_index: " << input_index;
+    }
+    if (prev_op_para == 0) {
+      is_parameter_involve_[input_index - 1] = false;
+    } else if (prev_op_para == 1) {
+      is_parameter_involve_[input_index - 1] = true;
+    } else {
+      MS_LOG(EXCEPTION) << name_ << " got wrong value: " << prev_op_para << ", input_index: " << input_index;
+    }
+    p_edge->set_parameter_involve(prev_op_para);
+  }
+  if (std::any_of(is_parameter_involve_.begin(), is_parameter_involve_.end(), [](bool value) { return value; })) {
+    // If anyone of the input is a parameter_involved, the output is parameter_involved.
+    is_output_parameter_involve_ = 1;
+  } else {
+    is_output_parameter_involve_ = 0;
+  }
+  // Set 'is_parameter_involve_' and 'is_output_parameter_involve_' into operatorCost, which are used in
+  // calculating 'inputs_in_memory' and 'output_in_memory', respectively.
+  operator_cost()->set_is_parameter_involve(is_parameter_involve_);
+  operator_cost()->set_output_parameter_involve(is_output_parameter_involve_);
+  // Calculating 'output_in_memory'
+  operator_cost()->CalculateOutputInMemory();
+  // Calculating 'inputs_in_memory'
+  std::map<size_t, bool> input_in_memory;
+  for (auto &p_edge : prev_edges) {
+    auto input_index = p_edge->next_op_input_index();
+    auto is_in_mem = p_edge->prev_operator()->operator_cost()->is_output_in_memory();
+    (void)input_in_memory.emplace(std::make_pair(input_index, is_in_mem));
+  }
+  operator_cost()->CalculateInputsInMemory(input_in_memory);
+
+  return is_output_parameter_involve_;
+}
+}  // namespace parallel
+}  // namespace mindspore

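The seed adjustment in ReplaceNodeInputOrAttrs is what keeps sharded random generation correct: devices computing a different slice must draw from a different seed, while replicas of the same slice must draw identically. A worked example under stated assumptions (8 devices, repeated_calc_num_ = 2, repetition on the right of the device matrix):

// seed_bias = rank_id / repeated_calc_num_
// ranks 0, 1 -> bias 0   (replicas of shard 0 share seed_ + 0)
// ranks 2, 3 -> bias 1   (shard 1 draws from seed_ + 1)
// ranks 4, 5 -> bias 2
// ranks 6, 7 -> bias 3
// Distinct shards get distinct random streams; replicas stay bitwise identical.
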
@@ -0,0 +1,64 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GAMMA_INFO_H_
+#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GAMMA_INFO_H_
+
+#include <string>
+#include <memory>
+#include <vector>
+
+#include "utils/hash_map.h"
+#include "ir/value.h"
+#include "frontend/parallel/auto_parallel/operator_costmodel.h"
+#include "frontend/parallel/ops_info/operator_info.h"
+#include "frontend/parallel/strategy.h"
+
+namespace mindspore {
+namespace parallel {
+class GammaInfo : public OperatorInfo {
+ public:
+  GammaInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+            const PrimitiveAttrs &attrs)
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GammaCost>()) {}
+  ~GammaInfo() override = default;
+
+  std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
+  Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); }
+  void UpdateShape(const CNodePtr &cnode);
+  void ReplaceNodeInputOrAttrs() override;
+  void ReComputeBatchSplitFlagList() override;
+  int64_t ComputeOpAndPrevEdgeParameterInvolved() override;
+
+ protected:
+  Status InferAttrs() override;
+  Status GetAttrs() override;
+  Status CheckStrategy(const StrategyPtr &strategy) override;
+  Status InferForwardCommunication() override { return SUCCESS; }
+  Status InferDevMatrixShape() override;
+  Status InferTensorMap() override;
+
+ private:
+  void ResetInputsShape();
+  bool IsNotSplittableStrategy(const Dimensions &strategy);
+
+  int64_t seed_ = 0;
+  int64_t seed2_ = 0;
+};
+}  // namespace parallel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GAMMA_INFO_H_

@@ -0,0 +1,136 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <utility>
+
+#include "frontend/parallel/ops_info/kldiv_loss_info.h"
+#include "frontend/parallel/graph_util/generate_graph.h"
+
+namespace mindspore {
+namespace parallel {
+Status KLDivLossInfo::GetAttrs() {
+  reduction_ = GetStringAttr(REDUCTION);
+  return SUCCESS;
+}
+
+Status KLDivLossInfo::CheckStrategy(const StrategyPtr &strategy) {
+  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Invalid strategy.";
+    return FAILED;
+  }
+
+  auto strategies = strategy->GetInputDim();
+  if (strategies[0] != strategies[1]) {
+    MS_LOG(ERROR) << name_ << ": The strategy of 'logits' and 'labels' must be the same, but got strategy "
+                  << StrategyToString(strategies);
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+Status KLDivLossInfo::InferDevMatrixShape() {
+  dev_matrix_shape_.clear();
+  auto strategies = strategy_->GetInputDim();
+  auto logits_strategy = strategies.at(0);
+  dev_matrix_shape_ = logits_strategy;
+
+  MS_LOG(INFO) << name_ << ": dev matrix is " << ShapeToString(dev_matrix_shape_);
+  return SUCCESS;
+}
+
+Status KLDivLossInfo::InferTensorMap() {
+  inputs_tensor_map_.clear();
+  outputs_tensor_map_.clear();
+
+  TensorMap sub_tensor_map;
+  auto strategies = strategy_->GetInputDim();
+  auto logits_strategy = strategies.at(0);
+  size_t size = logits_strategy.size();
+  for (size_t i = 0; i < size; ++i) {
+    sub_tensor_map.push_back(size - i - 1);
+  }
+
+  inputs_tensor_map_.push_back(sub_tensor_map);
+  inputs_tensor_map_.push_back(sub_tensor_map);
+
+  if (reduction_ == ATTR_NONE) {
+    (void)outputs_tensor_map_.emplace_back(sub_tensor_map);
+  } else {
+    (void)outputs_tensor_map_.emplace_back(TensorMap());
+  }
+  return SUCCESS;
+}
+
+std::vector<StrategyPtr> KLDivLossInfo::GenerateOpStrategies(int64_t stage_id) {
+  Shape splittable_input(inputs_shape_[0].size());
+  std::iota(splittable_input.begin(), splittable_input.end(), 1);
+  Shapes splittable_inputs(inputs_shape_.size(), splittable_input);
+  std::vector<StrategyPtr> sp_vector;
+  if (GenerateStrategiesForDependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
+    MS_LOG(EXCEPTION) << name_ << ": Generate strategies for dependent inputs() failed.";
+  }
+  if (sp_vector.empty()) {
+    MS_LOG(EXCEPTION) << name_ << ": No available strategy.";
+  }
+  return sp_vector;
+}
+
+Status KLDivLossInfo::InferGroup() {
+  Shape group_create_map;
+  if (repeated_calc_num_ > 1) {
+    if (repeated_num_in_dev_matrix_right_) {
+      group_create_map.push_back(0);
+    } else {
+      group_create_map.push_back(SizeToLong(dev_matrix_shape_.size()) - 1);
+    }
+  }
+
+  if (CreateGroupByTensorMap(group_create_map, &group_list_) != SUCCESS) {
+    ReportError(name_ + ": Create group failed.");
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+Status KLDivLossInfo::InferForwardCommunication() {
+  if (reduction_ == ATTR_NONE) {
+    MS_LOG(DEBUG) << name_ << ": reduction is " << reduction_ << ", there is no need to append reduce op.";
+    return SUCCESS;
+  }
+  if (InferGroup() != SUCCESS) {
+    MS_LOG(ERROR) << "Infer group failed";
+    return FAILED;
+  }
+  if (group_list_.empty()) {
+    MS_LOG(INFO) << name_ << ": Forward all reduce is not required";
+    return SUCCESS;
+  }
+
+  forward_op_.clear();
+  if (reduction_ == MEAN) {
+    auto element_type = outputs_dtype_->cast<mindspore::TensorTypePtr>()->element();
+    forward_op_ = CreateReduceMeanForwardOp(group_list_, element_type);
+  } else {
+    (void)forward_op_.emplace_back(CreateAllReduceOp(REDUCE_OP_SUM, group_list_[0].name()));
+  }
+  MS_LOG(INFO) << name_ << ": The group name of forward all reduce is " << group_list_[0].name();
+
+  return SUCCESS;
+}
+
+void KLDivLossInfo::ReComputeBatchSplitFlagList() { split_flag_list_.assign(inputs_shape_.size(), true); }
+}  // namespace parallel
+}  // namespace mindspore

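How the three `reduction` modes above map to communication, summarized as a reading of the code:

// "none": output keeps the inputs' tensor map; every device holds its own slice
//         of the elementwise loss, so no forward op is inserted.
// "sum":  output is a scalar; each device holds a partial sum, and an
//         AllReduce(SUM) over the inferred group completes it.
// "mean": output is a scalar; AllReduce(SUM) followed by RealDiv by the group
//         size (CreateReduceMeanForwardOp) turns partial sums into the global mean.
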
@@ -0,0 +1,57 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_KLDIV_LOSS_INFO_H_
+#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_KLDIV_LOSS_INFO_H_
+
+#include <memory>
+#include <vector>
+#include <string>
+
+#include "frontend/parallel/ops_info/operator_info.h"
+#include "frontend/parallel/strategy.h"
+#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
+
+namespace mindspore {
+namespace parallel {
+class KLDivLossInfo : public OperatorInfo {
+ public:
+  KLDivLossInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+                const PrimitiveAttrs &attrs)
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<KLDivLossCost>()) {}
+  ~KLDivLossInfo() override = default;
+
+  std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
+  Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); }
+  void ReComputeBatchSplitFlagList() override;
+
+ protected:
+  Status GetAttrs() override;
+  Status CheckStrategy(const StrategyPtr &strategy) override;
+  Status InferDevMatrixShape() override;
+  Status InferTensorMap() override;
+  Status InferForwardCommunication() override;
+
+ private:
+  Status InferGroup();
+
+  std::string reduction_;
+  std::vector<Group> group_list_;
+};
+}  // namespace parallel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_KLDIV_LOSS_INFO_H_

@@ -0,0 +1,176 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <utility>
+
+#include "frontend/parallel/ops_info/lin_space_info.h"
+#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
+#include "frontend/parallel/graph_util/generate_graph.h"
+
+namespace mindspore {
+namespace parallel {
+Status LinSpaceInfo::GetAttrs() {
+  auto output_0_shape = outputs_shape_.at(0);
+  output_size_ = output_0_shape.at(0);
+  return SUCCESS;
+}
+
+Status LinSpaceInfo::CheckStrategy(const StrategyPtr &strategy) {
+  auto strategies = strategy->GetInputDim();
+  if (strategies.size() != 1 || strategies[0].size() != 1) {
+    MS_LOG(ERROR) << name_ << ": The shape of input_strategy must be [1, 1], but got strategy "
+                  << StrategyToString(strategies);
+    return FAILED;
+  }
+
+  auto split_num = strategies[0][0];
+  if (output_size_ % split_num != 0) {
+    MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(strategies) << ", output size is "
+                  << output_size_ << " cannot be divisible by strategy value " << split_num;
+    return FAILED;
+  }
+
+  if (stage_device_size_ % split_num != 0) {
+    MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(strategies)
+                  << ", the device size in this stage is " << stage_device_size_
+                  << " cannot be divisible by the strategy value " << split_num;
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+Status LinSpaceInfo::InferDevMatrixShape() {
+  dev_matrix_shape_.clear();
+  auto strategies = strategy_->GetInputDim();
+  auto split_strategy = strategies.at(0);
+  dev_matrix_shape_.push_back(split_strategy.at(0));
+  MS_LOG(INFO) << name_ << ": The device matrix is " << ShapeToString(dev_matrix_shape_);
+  return SUCCESS;
+}
+
+Status LinSpaceInfo::InferTensorMap() {
+  inputs_tensor_map_.clear();
+  outputs_tensor_map_.clear();
+
+  inputs_tensor_map_.assign(2, TensorMap{});
+  (void)outputs_tensor_map_.emplace_back(TensorMap{0});
+  return SUCCESS;
+}
+
+std::vector<StrategyPtr> LinSpaceInfo::GenerateOpStrategies(int64_t stage_id) {
+  Shape split_input(1, 1);
+  Shape split_shape(1, output_size_);
+  Shapes splittable_inputs = {split_input};
+  Shapes splittable_shapes = {split_shape};
+
+  std::vector<StrategyPtr> sp_vector;
+  if (GenerateStrategiesForIndependentInputs(stage_id, splittable_shapes, splittable_inputs, &sp_vector) != SUCCESS) {
+    MS_LOG(EXCEPTION) << name_ << ": Generate strategies for independent inputs() failed.";
+  }
+  if (sp_vector.empty()) {
+    MS_LOG(EXCEPTION) << name_ << ": No available strategy.";
+  }
+  return sp_vector;
+}
+
+std::shared_ptr<Strategys> LinSpaceInfo::GenerateBatchStrategies() {
+  if (InferAttrs() != SUCCESS) {
+    MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
+  }
+
+  int64_t dev_num = g_device_manager->stage_device_num();
+  Strategys strategies = {Dimensions{dev_num}};
+  return std::make_shared<Strategys>(strategies);
+}
+
+int64_t LinSpaceInfo::GetSplitNum() {
+  auto strategies = strategy_->GetInputDim();
+  auto split_strategy = strategies.at(0);
+  auto split_num = split_strategy.at(0);
+  return split_num;
+}
+
+ReplaceGraphPtr LinSpaceInfo::replace_graph(const CNodePtr &cnode) {
+  auto split_num = GetSplitNum();
+  if (split_num > 1 && ComputeReplaceGraph(cnode) != SUCCESS) {
+    MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
+  }
+  return replace_graph_;
+}
+
+Status LinSpaceInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
+  auto split_num = GetSplitNum();
+  if (output_size_ % split_num != 0) {
+    return FAILED;
+  }
+  if (split_num == 1) {
+    MS_LOG(INFO) << name_ << ": split num is 1, no need to replace graph";
+    return SUCCESS;
+  }
+
+  GenerateGraph gen_g = GenerateGraph(attrs_);
+  if (gen_g.Init(cnode) != SUCCESS) {
+    MS_LOG(ERROR) << "GenerateGraph Init failed.";
+    return FAILED;
+  }
+  if (InferSliceId() != SUCCESS) {
+    MS_LOG(ERROR) << "Infer slice id failed.";
+    return FAILED;
+  }
+
+  int64_t slice_output_size = output_size_ / split_num;
+  auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), gen_g.virtual_input_node()});
+  auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), sub});
+  AnfNodePtr interval = nullptr;
+  if (output_size_ == 2) {
+    interval = sub;
+  } else {
+    auto interval_divisor = gen_g.PushBack({gen_g.NewOpInst(CAST), CreateInt32Tensor(output_size_ - 2), dtype});
+    interval = gen_g.PushBack({gen_g.NewOpInst(DIV), sub, interval_divisor});
+  }
+
+  // new_start = start + slice_id * slice_output_size * interval
+  // new_end = new_start + (slice_output_size - 1) * interval
+  // new_x = slice_output_size
+  auto offset_size = gen_g.PushBack({gen_g.NewOpInst(CAST), CreateInt32Tensor(slice_id_ * slice_output_size), dtype});
+  auto start_offset = gen_g.PushBack({gen_g.NewOpInst(MUL), interval, offset_size});
+  auto new_start = gen_g.PushBack({gen_g.NewOpInst(ADD), gen_g.virtual_input_node(), start_offset});
+  auto end_start_offset_size = gen_g.PushBack({gen_g.NewOpInst(CAST), CreateInt32Tensor(slice_output_size - 1), dtype});
+  auto start_end_offset = gen_g.PushBack({gen_g.NewOpInst(MUL), interval, end_start_offset_size});
+  auto new_end = gen_g.PushBack({gen_g.NewOpInst(ADD), new_start, start_end_offset});
+  auto lin_space = gen_g.PushBack({gen_g.NewOpInst(LIN_SPACE), new_start, new_end, CreatInt64Imm(slice_output_size)});
+
+  std::vector<std::pair<AnfNodePtr, int64_t>> inputs_nodes = {std::make_pair(sub, 1), std::make_pair(sub, 2),
+                                                              std::make_pair(new_start, 1)};
+  replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
+    std::make_pair(inputs_nodes, lin_space));
+  return SUCCESS;
+}
+
+Status LinSpaceInfo::InferSliceId() {
+  CheckGlobalDeviceManager();
+  int64_t rank = g_device_manager->rank_index_in_stage();
+  slice_id_ = rank;
+  if (repeated_calc_num_ > 1) {
+    if (repeated_num_in_dev_matrix_right_) {
+      slice_id_ /= dev_matrix_shape_.back();
+    } else {
+      slice_id_ %= dev_matrix_shape_.front();
+    }
+  }
+  return SUCCESS;
+}
+}  // namespace parallel
+}  // namespace mindspore

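The replaced graph rewrites a global LinSpace into a per-device one using the two formulas kept as comments in the code, new_start = start + slice_id * slice_output_size * interval and new_end = new_start + (slice_output_size - 1) * interval. A worked example, assuming the conventional LinSpace spacing interval = (stop - start) / (num - 1):

// LinSpace(start = 0, stop = 7, num = 8) split over 2 devices:
//   slice_output_size = 8 / 2 = 4, interval = (7 - 0) / (8 - 1) = 1
//   rank 0: new_start = 0 + 0*4*1 = 0, new_end = 0 + 3*1 = 3 -> LinSpace(0, 3, 4) = {0, 1, 2, 3}
//   rank 1: new_start = 0 + 1*4*1 = 4, new_end = 4 + 3*1 = 7 -> LinSpace(4, 7, 4) = {4, 5, 6, 7}
// Concatenated along the split dimension, the slices reproduce {0, ..., 7}.
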
@@ -0,0 +1,62 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_LIN_SPACE_INFO_H_
+#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_LIN_SPACE_INFO_H_
+
+#include <string>
+#include <memory>
+#include <vector>
+
+#include "utils/hash_map.h"
+#include "ir/value.h"
+#include "frontend/parallel/auto_parallel/operator_costmodel.h"
+#include "frontend/parallel/ops_info/operator_info.h"
+#include "frontend/parallel/strategy.h"
+
+namespace mindspore {
+namespace parallel {
+class LinSpaceInfo : public OperatorInfo {
+ public:
+  LinSpaceInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+               const PrimitiveAttrs &attrs)
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<LinSpaceCost>()) {}
+  ~LinSpaceInfo() override = default;
+
+  Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); }
+  std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
+  std::shared_ptr<Strategys> GenerateBatchStrategies() override;
+  ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
+
+ protected:
+  Status GetAttrs() override;
+  Status CheckStrategy(const StrategyPtr &strategy) override;
+  Status InferDevMatrixShape() override;
+  Status InferTensorMap() override;
+  Status InferForwardCommunication() override { return SUCCESS; }
+
+ private:
+  Status InferSliceId();
+  Status ComputeReplaceGraph(const CNodePtr &cnode);
+  int64_t GetSplitNum();
+
+  int64_t slice_id_ = 0;
+  int64_t output_size_ = 0;
+};
+}  // namespace parallel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_LIN_SPACE_INFO_H_

@@ -2176,5 +2176,29 @@ AnfNodePtr CreateTensorTupleAnfNodePtr(const tensor::TensorPtrList &tensor_tuple
   auto tensor_node = NewValueNode(tensor_ptr);
   return tensor_node->cast<AnfNodePtr>();
 }
+
+ForwardOp CreateReduceMeanForwardOp(const std::vector<Group> &forward_group, const TypePtr &dtype) {
+  // Create AllReduceSum op
+  Operator op0 = CreateAllReduceOp(REDUCE_OP_SUM, forward_group[0].name());
+  std::string group_name = forward_group[0].name();
+  MS_LOG(INFO) << "The group of forward all reduce is " << group_name;
+
+  // Create RealDiv op
+  OperatorName operator1_name = REAL_DIV;
+  std::vector<Device> device_list = forward_group[0].GetDevicesList();
+  auto divisor = static_cast<float>(device_list.size());
+  mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(divisor, dtype);
+  ValuePtr op1_param_value = MakeValue(tensor_ptr);
+  Attr op1_param = std::make_pair("divisor", op1_param_value);
+  OperatorParams operator1_params = {std::make_pair(op1_param, 2)};
+  OperatorAttrs operator1_attrs;
+  OperatorArgs operator1_args = std::make_pair(operator1_attrs, operator1_params);
+  Operator op1 = std::make_pair(operator1_name, operator1_args);
+  ForwardOp forward_op = {op0, op1};
+
+  std::string dtype_name = dtype->ToString();
+  MS_LOG(INFO) << "The divisor of Div op is " << device_list.size() << ", the dtype is " << dtype_name;
+  return forward_op;
+}
 }  // namespace parallel
 }  // namespace mindspore

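The helper moved here composes two primitives because a mean cannot be all-reduced directly: the mean of per-device means equals the global mean only when every device contributes an equal-size slice, which the even-divisibility checks elsewhere in the patch guarantee. Illustrative numbers (assumption: 4 equal slices):

// Per-device partial means: 1.0, 2.0, 3.0, 4.0
// Step 1, AllReduce(SUM):        1.0 + 2.0 + 3.0 + 4.0 = 10.0 on every device
// Step 2, RealDiv by group size: 10.0 / 4 = 2.5 == mean over the full tensor
// The divisor tensor is built from device_list.size() and cast to the output dtype.
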
@@ -112,7 +112,7 @@ class OperatorInfo {
   // In the inference phase, the memory cost is incurred only when the operator is critical. The size is calculated
   // by the output
   Status CalculateMemoryCostForInference();
-  int64_t ComputeOpAndPrevEdgeParameterInvolved();
+  virtual int64_t ComputeOpAndPrevEdgeParameterInvolved();

   ForwardOp forward_op() const { return forward_op_; }
   ForwardOp replace_op() const { return replace_op_; }

@@ -227,7 +227,7 @@ class OperatorInfo {
   void SetRepeatedCalcDevMatrix();
   void ResetTensorMapIfRepeatedCalc();
   Status CreateGroupByDim(size_t axis, std::vector<Group> *group);
-  Status InferAttrs();
+  virtual Status InferAttrs();
   void ResetQueueMember();
   Status InitWithAutoRepeatCalc(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy);
   Status InitForCostModelWithAutoRepeatCalc(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy);

@@ -374,6 +374,8 @@ ValuePtr MakeListValue(const std::vector<int64_t> &v);
 ValuePtr MakeTupleListValue(const Shapes &v);
 AnfNodePtr CreateValueTupleAnfNodePtr(const std::vector<int64_t> &value_tuple);
 AnfNodePtr CreateTensorTupleAnfNodePtr(const tensor::TensorPtrList &tensor_tuple);
+
+ForwardOp CreateReduceMeanForwardOp(const std::vector<Group> &forward_group, const TypePtr &dtype);
 }  // namespace parallel
 }  // namespace mindspore

@@ -70,5 +70,8 @@
 #include "frontend/parallel/ops_info/addn_info.h"
 #include "frontend/parallel/ops_info/inplace_add_info.h"
 #include "frontend/parallel/ops_info/cdist_info.h"
+#include "frontend/parallel/ops_info/gamma_info.h"
+#include "frontend/parallel/ops_info/kldiv_loss_info.h"
+#include "frontend/parallel/ops_info/lin_space_info.h"

 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_

@@ -260,6 +260,9 @@ constexpr char POOLED_WIDTH[] = "pooled_width";
 constexpr char SPATIAL_SCALE[] = "spatial_scale";
 constexpr char SAMPLE_NUM[] = "sample_num";
 constexpr char ROI_END_MODE[] = "roi_end_mode";
+constexpr char REDUCTION[] = "reduction";
+constexpr char MEAN[] = "mean";
+constexpr char ATTR_NONE[] = "none";

 // Operator
 constexpr char VIRTUAL_DIV[] = "_VirtualDiv";

@@ -477,6 +480,12 @@ constexpr char INPLACE_SUB[] = "InplaceSub";
 constexpr char CDIST[] = "Cdist";
 constexpr char L2_LOSS[] = "L2Loss";
 constexpr char LERP[] = "Lerp";
+constexpr char SQUARED_DIFFERENCE[] = "SquaredDifference";
+constexpr char ERFINV[] = "Erfinv";
+constexpr char SPLITV[] = "SplitV";
+constexpr char GAMMA[] = "Gamma";
+constexpr char KLDIV_LOSS[] = "KLDivLoss";
+constexpr char LIN_SPACE[] = "LinSpace";

 // pipeline
 constexpr size_t PIPELINE_FUSTION_OFFSET = 100;

@@ -201,30 +201,6 @@ Status ReduceMethod::InferForwardCommunication() {
   return SUCCESS;
 }

-ForwardOp CreateReduceMeanForwardOp(const std::vector<Group> &forward_group, const TypePtr &dtype) {
-  // Create AllReduceSum op
-  Operator op0 = CreateAllReduceOp(REDUCE_OP_SUM, forward_group[0].name());
-  std::string group_name = forward_group[0].name();
-  MS_LOG(INFO) << "The group of forward all reduce is " << group_name;
-
-  // Create RealDiv op
-  OperatorName operator1_name = REAL_DIV;
-  std::vector<Device> device_list = forward_group[0].GetDevicesList();
-  auto divisor = static_cast<float>(device_list.size());
-  mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(divisor, dtype);
-  ValuePtr op1_param_value = MakeValue(tensor_ptr);
-  Attr op1_param = std::make_pair("divisor", op1_param_value);
-  OperatorParams operator1_params = {std::make_pair(op1_param, 2)};
-  OperatorAttrs operator1_attrs;
-  OperatorArgs operator1_args = std::make_pair(operator1_attrs, operator1_params);
-  Operator op1 = std::make_pair(operator1_name, operator1_args);
-  ForwardOp forward_op = {op0, op1};
-
-  std::string dtype_name = dtype->ToString();
-  MS_LOG(INFO) << "The divisor of Div op is " << device_list.size() << ", the dtype is " << dtype_name;
-  return forward_op;
-}
-
 Status ReduceMeanInfo::InferForwardCommunication() {
   auto strategies = strategy_->GetInputDim();
   Dimensions stra = strategies.at(0);

@@ -801,7 +777,7 @@ Status SquareSumAllInfo::InferGroup() {
   // if repeated calculation and the repeated_calc_num_ insert to the first dimension of dev matrix,
   // it need to handle the first dimension of map.
   if ((dev_matrix_shape_.size() > size) && !repeated_num_in_dev_matrix_right_) {
-    group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1)));
+    group_creat_map.push_back(SizeToLong(dev_matrix_shape_.size() - size_t(1)));
   }

   for (size_t index = 0; index < size; ++index) {

@@ -30,9 +30,9 @@ namespace mindspore {
 namespace parallel {
 Status SplitInfo::GetAttrs() {
   int64_t axis = 0;
-  int64_t output_num = 0;

-  auto axis_iter = attrs_.find(AXIS);
+  std::string attr_axis_name = GetSplitAxisAttrName();
+  auto axis_iter = attrs_.find(attr_axis_name);
   if (axis_iter != attrs_.end()) {
     MS_EXCEPTION_IF_NULL(axis_iter->second);
     if (axis_iter->second->isa<Int64Imm>()) {

@@ -56,21 +56,6 @@ Status SplitInfo::GetAttrs() {
   }
   axis_ = LongToSize(axis);

-  auto output_num_iter = attrs_.find(OUTPUT_NUM);
-  if (output_num_iter != attrs_.end()) {
-    MS_EXCEPTION_IF_NULL(output_num_iter->second);
-    if (output_num_iter->second->isa<Int64Imm>()) {
-      output_num = output_num_iter->second->cast<Int64ImmPtr>()->value();
-    } else {
-      MS_LOG(ERROR) << name_ << ": The value of output_num is not int";
-      return FAILED;
-    }
-  } else {
-    MS_LOG(ERROR) << name_ << ": Can not find the output_num attr";
-    return FAILED;
-  }
-  output_num_ = LongToSize(output_num);
-
   return SUCCESS;
 }

@@ -152,6 +137,9 @@ std::vector<StrategyPtr> SplitInfo::GenerateOpStrategies(int64_t stage_id) {
   if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_input, &sp_vector) != SUCCESS) {
     MS_LOG(EXCEPTION) << name_ << ": Generate strategies failed";
   }
+  if (sp_vector.empty()) {
+    MS_LOG(EXCEPTION) << name_ << ": No available strategy";
+  }

   return sp_vector;
 }

@@ -46,10 +46,21 @@ class SplitInfo : public OperatorInfo {
   Status InferDevMatrixShape() override;
   Status InferTensorMap() override;
   Status InferAsLossDivisor() override;
+  virtual std::string GetSplitAxisAttrName() const { return AXIS; }

  private:
   size_t axis_ = 0;
-  size_t output_num_ = 0;
+};
+
+class SplitVInfo : public SplitInfo {
+ public:
+  SplitVInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+             const PrimitiveAttrs &attrs)
+      : SplitInfo(operator_name, inputs_shape, outputs_shape, attrs) {}
+  ~SplitVInfo() override = default;
+
+ protected:
+  std::string GetSplitAxisAttrName() const override { return SPLIT_DIM; };
 };
 }  // namespace parallel
 }  // namespace mindspore
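The new virtual GetSplitAxisAttrName hook is the whole mechanism behind SplitVInfo: the base class resolves the axis through one code path, and each subclass only names the attribute it reads (`axis` for Split, `split_dim` for SplitV). A minimal Python sketch of the same dispatch, with made-up attribute dictionaries standing in for the C++ attrs_ map:

```python
# Sketch only: the virtual-attr-name dispatch shared by SplitInfo/SplitVInfo.
# The dictionaries below are hypothetical stand-ins for the C++ attrs_ map.
class SplitInfoSketch:
    def split_axis_attr_name(self) -> str:
        return "axis"

    def get_axis(self, attrs: dict) -> int:
        # Split and SplitV resolve their axis through the same code path;
        # only the attribute key differs per subclass.
        name = self.split_axis_attr_name()
        if name not in attrs:
            raise KeyError(f"Can not find the {name} attr")
        return attrs[name]

class SplitVInfoSketch(SplitInfoSketch):
    def split_axis_attr_name(self) -> str:
        return "split_dim"

print(SplitInfoSketch().get_axis({"axis": 2}))        # 2
print(SplitVInfoSketch().get_axis({"split_dim": 1}))  # 1
```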
@@ -28,6 +28,19 @@

 namespace mindspore {
 namespace parallel {
+Status UniformRealInfo::InferAttrs() {
+  if (infer_attrs_completed_) {
+    return SUCCESS;
+  }
+
+  if (GetAttrs() != SUCCESS) {
+    return FAILED;
+  }
+  ResetInputsShape();
+  infer_attrs_completed_ = true;
+  return SUCCESS;
+}
+
 Status UniformRealInfo::GetAttrs() {
   seed_ = GetIntAttr(SEED);
   if (seed_ < 0) {
@@ -39,10 +52,6 @@ Status UniformRealInfo::GetAttrs() {
     MS_LOG(ERROR) << name_ << ": Seed2 must be greater or equal to zero, but got " << seed2_;
     return FAILED;
   }
-  ValueTuplePtr shape_value = input_value_[0]->cast<ValueTuplePtr>();
-  MS_EXCEPTION_IF_NULL(shape_value);
-  inputs_shape_.push_back(GetValue<Shape>(shape_value));

   return SUCCESS;
 }
@@ -97,7 +106,10 @@ std::vector<StrategyPtr> UniformRealInfo::GenerateOpStrategies(int64_t stage_id) {

   std::vector<StrategyPtr> sp_vector;
   if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
-    MS_LOG(EXCEPTION) << name_ << " : Generate strategies for independent inputs() failed.";
+    MS_LOG(EXCEPTION) << name_ << ": Generate strategies for independent inputs() failed.";
+  }
+  if (sp_vector.empty()) {
+    MS_LOG(EXCEPTION) << name_ << ": No available strategy.";
   }
   return sp_vector;
 }
@@ -120,23 +132,29 @@ void UniformRealInfo::UpdateShape(const CNodePtr &cnode) {
 }

 void UniformRealInfo::ReplaceNodeInputOrAttrs() {
+  // Replace the input 'shape' with the slice shape
   auto cnode = cnode_;
-  int64_t split_num = 1;
   UpdateShape(cnode);
-  std::vector<Dimensions> stra = strategy_->GetInputDim();
-  for (size_t i = 0; i < stra[0].size(); i++) {
-    split_num *= stra[0][i];
-  }
-  int64_t device_num = stage_device_size_;
-  int64_t rank_id = g_device_manager->rank_index_in_stage();

-  if (device_num != split_num) {
-    auto split_group_num = device_num / split_num;
-    int64_t seed_bias = rank_id / split_group_num;
-    auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
-    prim->set_attr(SEED, MakeValue(seed_ + seed_bias));
-    prim->set_attr(SEED2, MakeValue(seed2_ + seed_bias));
+  // Update the seed according to rank_id
+  int64_t rank_id = g_device_manager->rank_index_in_stage();
+  int64_t seed_bias;
+  if (repeated_num_in_dev_matrix_right_) {
+    seed_bias = rank_id / repeated_calc_num_;
+  } else {
+    int64_t device_num = stage_device_size_;
+    seed_bias = rank_id % (device_num / repeated_calc_num_);
   }
+
+  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+  prim->set_attr(SEED, MakeValue(seed_ + seed_bias));
+  prim->set_attr(SEED2, MakeValue(seed2_ + seed_bias));
+}
+
+void UniformRealInfo::ResetInputsShape() {
+  ValueTuplePtr shape_value = input_value_[0]->cast<ValueTuplePtr>();
+  MS_EXCEPTION_IF_NULL(shape_value);
+  inputs_shape_.push_back(GetValue<Shape>(shape_value));
 }
 }  // namespace parallel
 }  // namespace mindspore
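The rewritten seed logic derives one bias per group of devices so that replicas of the same slice stay in sync while different slices draw different random streams. A small Python sketch of the arithmetic, assuming the repeated-right layout used in the tests below (the rank and device counts are example values):

```python
def seed_bias(rank_id: int, device_num: int, repeated_calc_num: int,
              repeated_right: bool) -> int:
    # Mirrors UniformRealInfo::ReplaceNodeInputOrAttrs: ranks holding the
    # same slice get the same bias, so replicas stay in sync while distinct
    # slices use distinct seeds.
    if repeated_right:
        return rank_id // repeated_calc_num
    return rank_id % (device_num // repeated_calc_num)

# Example: 8 devices, strategy (2, 2) -> 4 slices, repeated_calc_num = 2.
for rank in range(8):
    print(rank, seed_bias(rank, 8, 2, True))
# ranks 0-1 share bias 0, ranks 2-3 bias 1, and so on; the op then runs
# with seed_ + bias and seed2_ + bias on each group.
```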
@@ -14,8 +14,8 @@
  * limitations under the License.
  */

-#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNIFORMREAL_INFO_H_
-#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNIFORMREAL_INFO_H_
+#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNIFORM_REAL_INFO_H_
+#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNIFORM_REAL_INFO_H_

 #include <string>
 #include <memory>
@@ -33,7 +33,7 @@ class UniformRealInfo : public OperatorInfo {
  public:
   UniformRealInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                   const PrimitiveAttrs &attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<MaxPoolCost>()) {}
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<UniformRealCost>()) {}
   ~UniformRealInfo() override = default;

   std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
@@ -42,6 +42,7 @@ class UniformRealInfo : public OperatorInfo {
   void ReplaceNodeInputOrAttrs() override;

  protected:
+  Status InferAttrs() override;
   Status GetAttrs() override;
   Status CheckStrategy(const StrategyPtr &strategy) override;
   Status InferForwardCommunication() override { return SUCCESS; }
@@ -49,11 +50,12 @@ class UniformRealInfo : public OperatorInfo {
   Status InferTensorMap() override;

  private:
+  void ResetInputsShape();

   int64_t seed_ = 0;
   int64_t seed2_ = 0;
 };

 }  // namespace parallel
 }  // namespace mindspore

-#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNIFORMREAL_INFO_H_
+#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNIFORM_REAL_INFO_H_
@@ -71,8 +71,9 @@ Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) {
       max_size_strategy_dim_ = i;
     }
   }
-  if (std::find(stra[max_size_strategy_dim_].begin(), stra[max_size_strategy_dim_].end(), shard_num_) ==
-      stra[max_size_strategy_dim_].end()) {
+  if (!stra[max_size_strategy_dim_].empty() &&
+      std::find(stra[max_size_strategy_dim_].begin(), stra[max_size_strategy_dim_].end(), shard_num_) ==
+        stra[max_size_strategy_dim_].end()) {
     MS_LOG(ERROR) << name_
                   << ": For each dataset input, the shard strategy can be not shard, "
                      "or shard in one dim with the same shard size between each input."
@@ -96,7 +97,7 @@ Status VirtualDatasetInfo::InferForwardCommunication() { return SUCCESS; }
 Status VirtualDatasetInfo::InferTensorMap() {
   auto dev_mat_origin = strategy_->GetInputDim()[max_size_strategy_dim_];
   auto slice_dim_iter = std::find(dev_mat_origin.begin(), dev_mat_origin.end(), shard_num_);
-  if (slice_dim_iter == dev_mat_origin.end()) {
+  if (!dev_mat_origin.empty() && slice_dim_iter == dev_mat_origin.end()) {
     MS_LOG(ERROR) << name_ << ": The dataset shard strategy only support shard in one dim.";
     return FAILED;
   }
@@ -108,8 +109,7 @@ Status VirtualDatasetInfo::InferTensorMap() {
     if (dim == 1) {
       tensor_map_index.push_back(MAP_NONE);
     } else if (dim == shard_num_) {
-      if (repeated_num_in_dev_matrix_right_ && dev_matrix_shape_.size() != dev_mat_origin.size() &&
-          is_auto_parallel_) {
+      if (repeated_calc_num_ > 1 && repeated_num_in_dev_matrix_right_ && is_auto_parallel_) {
         tensor_map_index.push_back(dev_mat_origin.size() - slice_dim);
       } else {
         tensor_map_index.push_back(dev_mat_origin.size() - 1 - slice_dim);
@@ -176,8 +176,10 @@ Status VirtualDatasetInfo::GenerateStrategies(int64_t stage_id) {
     }
     for (auto &shape : inputs_shape_) {
       Shape temp;
-      temp.emplace_back(SizeToLong(total_dev_num));
-      (void)temp.insert(temp.end(), shape.size() - 1, 1);
+      if (!shape.empty()) {
+        temp.emplace_back(SizeToLong(total_dev_num));
+        (void)temp.insert(temp.end(), shape.size() - 1, 1);
+      }
       strategy.push_back(temp);
     }
   }
@@ -179,7 +179,7 @@ bool IsSplittableOperator(const std::string &op_name) {
     RESIZE_NEAREST_NEIGHBOR, CUM_SUM, FAST_GELU, IOU, BOUNDING_BOX_ENCODE, RANDOM_CHOICE_WITH_MASK, CROP_AND_RESIZE,
     ROI_ALIGN, IS_FINITE, RINT, HSHRINK, HSIGMOID, MISH, SELU, SOFT_SHRINK, XLOGY, XDIVY, CUM_PROD, BITWISE_AND,
     BITWISE_OR, BITWISE_XOR, MUL_NO_NAN, TRUNCATE_DIV, TRUNCATE_MOD, INPLACE_ADD, INPLACE_SUB, L2_LOSS, LERP, ADDN,
-    CDIST};
+    CDIST, SQUARED_DIFFERENCE, ERFINV, MASKED_FILL, SPLITV, GAMMA, KLDIV_LOSS, LIN_SPACE};
   // clang-format on

   auto iter = splittable_op.find(op_name);
@@ -1809,18 +1809,19 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
     std::vector<ValuePtr> elements;
     for (size_t i = 0; i < shape_list[0].size(); i++) {
       if (shape_list[0][i].empty()) {
-        MS_LOG(EXCEPTION) << "shape_list[ " << i << " ].size() is zero";
+        (void)elements.emplace_back(MakeValue(Dimensions()));
+        continue;
       }
       Dimensions input_strategy;
-      if (!shape_list[0][i].empty() && shape_list[0][i][0] % dev_num == 0) {
+      if (shape_list[0][i][0] % dev_num == 0) {
         input_strategy.push_back(dev_num);
-      } else if (!shape_list[0][i].empty()) {
+      } else {
         input_strategy.push_back(1);
       }
       for (size_t j = 1; j < shape_list[0][i].size(); j++) {
         input_strategy.push_back(1);
       }
-      elements.push_back(MakeValue(input_strategy));
+      (void)elements.emplace_back(MakeValue(input_strategy));
     }
     ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
     attrs_temp[IN_STRATEGY] = strategy;
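With this change a scalar (zero-rank) dataset input no longer aborts compilation; it simply contributes an empty per-input strategy. A rough Python analogue of the rewritten loop, with invented shape lists:

```python
def build_dataset_strategy(shapes, dev_num):
    # Rough analogue of SetVirtualDatasetStrategy: scalar inputs get an
    # empty strategy; only the first axis is ever sharded, and only when
    # it divides evenly by the device count.
    elements = []
    for shape in shapes:
        if not shape:              # scalar input: empty strategy, no exception
            elements.append(())
            continue
        first = dev_num if shape[0] % dev_num == 0 else 1
        elements.append((first,) + (1,) * (len(shape) - 1))
    return tuple(elements)

print(build_dataset_strategy([[64, 32], [], [30, 8]], dev_num=8))
# ((8, 1), (), (1, 1))
```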
@@ -1110,3 +1110,89 @@ def test_matmul_xlogy_broadcast():
     y = Tensor(np.ones([32, 1]), dtype=ms.float32)
     b = Tensor(np.ones([1, 64]), dtype=ms.float32)
     compile_net(net, x, y, b)
+
+
+def test_matmul_squared_difference_broadcast():
+    """
+    Feature: distribute operator SquaredDifference in auto parallel.
+    Description: mul-SquaredDifference net with strategy in semi auto parallel.
+    Expectation: compile done without error.
+    """
+    class Net(nn.Cell):
+        def __init__(self, strategy1, strategy2):
+            super().__init__()
+            self.matmul = P.MatMul().shard(strategy1)
+            self.squared_difference = P.SquaredDifference().shard(strategy2)
+
+        def construct(self, x, y, b):
+            out = self.matmul(x, y)
+            out = self.squared_difference(out, b)
+            return out
+
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((2, 4), (4, 1))
+    strategy2 = ((4, 1), (1, 2))
+    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
+
+    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([32, 1]), dtype=ms.float32)
+    b = Tensor(np.ones([1, 64]), dtype=ms.float32)
+    compile_net(net, x, y, b)
+
+
+def test_matmul_masked_fill_broadcast_with_value_float():
+    """
+    Feature: distribute operator MaskedFill in auto parallel.
+    Description: mul-MaskedFill net with strategy in semi auto parallel.
+    Expectation: compile done without error.
+    """
+    class Net(nn.Cell):
+        def __init__(self, strategy1, strategy2):
+            super().__init__()
+            self.matmul = P.MatMul().shard(strategy1)
+            self.masked_fill = P.MaskedFill().shard(strategy2)
+            self.value = 1.0
+
+        def construct(self, x, y, b):
+            out = self.matmul(x, y)
+            out = self.masked_fill(out, b, self.value)
+            return out
+
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((2, 4), (4, 1))
+    strategy2 = ((4, 1), (1, 2))
+    net = Net(strategy1, strategy2)
+
+    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([32, 1]), dtype=ms.float32)
+    b = Tensor(np.ones([1, 64]), dtype=ms.bool_)
+    compile_net(net, x, y, b)
+
+
+def test_matmul_masked_fill_broadcast_with_value_tensor():
+    """
+    Feature: distribute operator MaskedFill in auto parallel.
+    Description: mul-MaskedFill net with strategy in semi auto parallel.
+    Expectation: compile done without error.
+    """
+    class Net(nn.Cell):
+        def __init__(self, strategy1, strategy2):
+            super().__init__()
+            self.matmul = P.MatMul().shard(strategy1)
+            self.masked_fill = P.MaskedFill().shard(strategy2)
+            self.value = Tensor(1.0, ms.float32)
+
+        def construct(self, x, y, b):
+            out = self.matmul(x, y)
+            out = self.masked_fill(out, b, self.value)
+            return out
+
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((2, 4), (4, 1))
+    strategy2 = ((4, 1), (1, 2), ())
+    net = Net(strategy1, strategy2)
+
+    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([32, 1]), dtype=ms.float32)
+    b = Tensor(np.ones([1, 64]), dtype=ms.bool_)
+    compile_net(net, x, y, b)
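One detail worth flagging in these tests: a scalar float fill value takes no entry in the MaskedFill strategy, while a 0-d Tensor fill value needs an explicit empty tuple, which is why the last test passes ((4, 1), (1, 2), ()). Spelled out as a sketch:

```python
# Hedged sketch of the two MaskedFill strategy spellings used above.
# With a float fill value, only the data and mask inputs appear.
strategy_with_float_value = ((4, 1), (1, 2))
# With a 0-d Tensor fill value, a trailing empty tuple stands for it.
strategy_with_tensor_value = ((4, 1), (1, 2), ())

# The mask [1, 64] broadcasts against the [64, 64] matmul output, so its
# strategy may differ on the broadcast axis but must divide its real axes.
```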
@@ -1173,3 +1173,34 @@ def test_matmul_soft_shrink():
     y = Tensor(np.random.uniform(-5, 5, size=(32, 64)), dtype=ms.float32)
     b = Tensor(np.random.uniform(-5, 5, size=(64, 64)), dtype=ms.float32)
     compile_net(net, x, y, b)
+
+
+def test_matmul_erfinv():
+    """
+    Feature: distribute operator Erfinv in auto parallel.
+    Description: matmul-erfinv-matmul net with strategy in semi auto parallel.
+    Expectation: compile done without error.
+    """
+    class Net(nn.Cell):
+        def __init__(self, strategy1, strategy2):
+            super().__init__()
+            self.matmul = P.MatMul().shard(strategy1)
+            self.erfinv = P.Erfinv().shard(strategy2)
+            self.matmul2 = P.MatMul().shard(strategy1)
+
+        def construct(self, x, y, b):
+            out = self.matmul(x, y)
+            out = self.erfinv(out)
+            out = self.matmul2(out, b)
+            return out
+
+    context.set_auto_parallel_context(device_num=8, global_rank=0)
+    strategy1 = ((2, 2), (2, 2))
+    strategy2 = ((4, 2),)
+    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+
+    x = Tensor(np.random.uniform(-5, 5, size=(128, 32)), dtype=ms.float32)
+    y = Tensor(np.random.uniform(-5, 5, size=(32, 64)), dtype=ms.float32)
+    b = Tensor(np.random.uniform(-5, 5, size=(64, 64)), dtype=ms.float32)
+    compile_net(net, x, y, b)
@@ -0,0 +1,94 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import pytest
+
+import mindspore.common.dtype as mstype
+from mindspore import Tensor, context
+from mindspore.nn import Cell
+from mindspore.ops import operations as P
+
+from parallel.utils.utils import ParallelValidator, compile_net
+
+SEED_ = 1
+SEED2_ = 1
+alpha_ = Tensor(np.array([1.0]), mstype.float32)
+beta_ = Tensor(np.array([1.0]), mstype.float32)
+
+
+class Net(Cell):
+    def __init__(self, seed, seed2, strategy=None):
+        super(Net, self).__init__()
+        self.gamma = P.Gamma(seed, seed2).shard(strategy)
+
+    def construct(self, shape, alpha, beta):
+        out = self.gamma(shape, alpha, beta)
+        return out
+
+
+def test_gamma_auto_parallel():
+    """
+    Features: test Gamma auto parallel
+    Description: auto parallel
+    Expectation: compile success
+    """
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0, full_batch=True)
+    net = Net(SEED_, SEED2_)
+    shape = (4, 4, 4)
+    compile_net(net, shape, alpha_, beta_)
+
+
+def test_gamma_data_parallel():
+    """
+    Features: test Gamma data parallel
+    Description: data parallel
+    Expectation: compile success
+    """
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=1)
+    net = Net(SEED_, SEED2_)
+    shape = (8, 8)
+    phase = compile_net(net, shape, alpha_, beta_)
+
+    validator = ParallelValidator(net, phase)
+    assert validator.check_node_attrs("Gamma-0", {"seed": 2, "seed2": 2})
+
+
+def test_gamma_model_parallel():
+    """
+    Features: test Gamma model parallel
+    Description: model parallel
+    Expectation: compile success
+    """
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=5)
+    shape = (8, 8)
+    strategy = ((2, 2), (1,), (1,))
+    net = Net(SEED_, SEED2_, strategy)
+    phase = compile_net(net, shape, alpha_, beta_)
+    validator = ParallelValidator(net, phase)
+    assert validator.check_node_attrs("Gamma-0", {"seed": 3, "seed2": 3})
+
+
+def test_gamma_strategy_error():
+    """
+    Features: test Gamma strategy error
+    Description: invalid strategy
+    Expectation: Raise RuntimeError
+    """
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    shape = (8, 8)
+    strategy = ((2, 2), (2,), (1,))
+    net = Net(SEED_, SEED2_, strategy)
+    with pytest.raises(RuntimeError):
+        compile_net(net, shape, alpha_, beta_)
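The seed attributes these tests assert appear to follow the same seed-bias rule as UniformReal. A short worked check, assuming repeated-right placement (the ranks and strategies are taken from the tests):

```python
# Worked check of the Gamma test expectations, reusing the seed-bias rule.
def seed_bias(rank_id, repeated_calc_num):
    # repeated-right layout assumed, as in the semi-auto tests above
    return rank_id // repeated_calc_num

# Data parallel: shape (8, 8) on 8 devices, no replication -> repeated_calc_num = 1.
assert 1 + seed_bias(1, 1) == 2    # rank 1 -> seed 2, matching the validator
# Model parallel: strategy ((2, 2), ...) uses 4 slices -> repeated_calc_num = 2.
assert 1 + seed_bias(5, 2) == 3    # rank 5 -> seed 3, matching the validator
print("seed expectations consistent")
```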
@@ -0,0 +1,127 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+import mindspore.common.dtype as mstype
+from mindspore import Tensor, context
+from mindspore.nn import Cell
+from mindspore.ops import operations as P
+
+from parallel.utils.utils import ParallelValidator, compile_net
+
+logits_ = Tensor(np.random.uniform(0, 1, [8, 8]), mstype.float32)
+labels_ = Tensor(np.random.randint(0, 10, [8, 8]), mstype.float32)
+
+
+class Net(Cell):
+    def __init__(self, reduction, strategy=None):
+        super(Net, self).__init__()
+        self.kldiv_loss = P.KLDivLoss(reduction).shard(strategy)
+
+    def construct(self, logits, labels):
+        out = self.kldiv_loss(logits, labels)
+        return out
+
+
+def test_kldiv_loss_mean_auto_parallel():
+    """
+    Features: test KLDivLoss auto parallel
+    Description: auto parallel, reduction is 'mean'
+    Expectation: compile success
+    """
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0, full_batch=True)
+    reduction = 'mean'
+    net = Net(reduction)
+    compile_net(net, logits_, labels_)
+
+
+def test_kldiv_loss_none_auto_parallel():
+    """
+    Features: test KLDivLoss auto parallel
+    Description: auto parallel, reduction is 'none'
+    Expectation: compile success
+    """
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0, full_batch=True)
+    reduction = 'none'
+    net = Net(reduction)
+    compile_net(net, logits_, labels_)
+
+
+def test_kldiv_loss_sum_auto_parallel():
+    """
+    Features: test KLDivLoss auto parallel
+    Description: auto parallel, reduction is 'sum'
+    Expectation: compile success
+    """
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0, full_batch=True)
+    reduction = 'sum'
+    net = Net(reduction)
+    compile_net(net, logits_, labels_)
+
+
+def test_kldiv_loss_mean_data_parallel():
+    """
+    Features: test KLDivLoss data parallel
+    Description: data parallel, reduction is 'mean'
+    Expectation: compile success
+    """
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=1)
+    reduction = 'mean'
+    net = Net(reduction)
+    phase = compile_net(net, logits_, labels_)
+    validator = ParallelValidator(net, phase)
+    assert validator.check_node_inputs('AllReduce-0', ['KLDivLoss-0'])
+    assert validator.check_node_attrs('AllReduce-0', {'op': 'sum'})
+
+
+def test_kldiv_loss_none_data_parallel():
+    """
+    Features: test KLDivLoss data parallel
+    Description: data parallel, reduction is 'none'
+    Expectation: compile success
+    """
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=1)
+    reduction = 'none'
+    net = Net(reduction)
+    compile_net(net, logits_, labels_)
+
+
+def test_kldiv_loss_none_model_parallel():
+    """
+    Features: test KLDivLoss model parallel
+    Description: model parallel, reduction is 'none'
+    Expectation: compile success
+    """
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=5)
+    reduction = 'none'
+    strategy = ((2, 2), (2, 2))
+    net = Net(reduction, strategy)
+    compile_net(net, logits_, labels_)
+
+
+def test_kldiv_loss_mean_model_parallel():
+    """
+    Features: test KLDivLoss model parallel
+    Description: model parallel, reduction is 'mean'
+    Expectation: compile success
+    """
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=5)
+    reduction = 'mean'
+    strategy = ((4, 2), (4, 2))
+    net = Net(reduction, strategy)
+    phase = compile_net(net, logits_, labels_)
+    validator = ParallelValidator(net, phase)
+    assert validator.check_node_inputs('AllReduce-0', ['KLDivLoss-0'])
+    assert validator.check_node_attrs('AllReduce-0', {'op': 'sum'})
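The validator checks encode the expected lowering for reduction='mean': each device reduces its local slice, an AllReduce with op='sum' combines the partials, and the framework handles the final rescaling by the global element count. A toy numpy check that the partial sums compose (the data here is invented):

```python
import numpy as np

# Toy check: a mean over the full tensor equals the AllReduce'd sum of
# per-device partial sums divided by the global element count.
full = np.arange(16, dtype=np.float32).reshape(4, 4)
slices = np.split(full, 4, axis=0)          # 4 "devices", one row slice each
partial_sums = [s.sum() for s in slices]    # local reduction on each device
allreduced = sum(partial_sums)              # AllReduce with op='sum'
assert np.isclose(allreduced / full.size, full.mean())
```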
@@ -18,7 +18,7 @@ from mindspore import Tensor, context
 from mindspore.nn import Cell
 from mindspore.ops import operations as P

-from parallel.utils.utils import compile_net
+from parallel.utils.utils import ParallelValidator, compile_net

 x_ = Tensor(np.random.normal(size=[32, 8, 8]).astype(np.float32))
@@ -52,4 +52,21 @@ def test_l2_loss_model_parallel():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy = ((2, 2, 2),)
     net = Net(strategy)
-    compile_net(net, x_)
+    phase = compile_net(net, x_)
+    validator = ParallelValidator(net, phase)
+    assert validator.check_node_inputs('AllReduce-0', ['L2Loss-0'])
+    assert validator.check_node_attrs('AllReduce-0', {'op': 'sum'})
+
+
+def test_l2_loss_data_parallel():
+    """
+    Feature: test L2Loss data parallel
+    Description: data parallel
+    Expectation: compile success
+    """
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    net = Net()
+    phase = compile_net(net, x_)
+    validator = ParallelValidator(net, phase)
+    assert validator.check_node_inputs('AllReduce-0', ['L2Loss-0'])
+    assert validator.check_node_attrs('AllReduce-0', {'op': 'sum'})
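L2Loss is sum(x^2)/2, so whichever dimensions the strategy slices, per-device results combine with a plain sum, which is exactly the AllReduce-after-L2Loss pattern both tests assert. A small numpy confirmation under that assumption:

```python
import numpy as np

x = np.random.normal(size=(4, 4)).astype(np.float32)
full = 0.5 * (x ** 2).sum()                # L2Loss on the whole tensor
parts = [0.5 * (s ** 2).sum()              # local L2Loss on two "devices"
         for s in np.split(x, 2, axis=1)]
assert np.isclose(sum(parts), full)        # AllReduce(op='sum') recovers it
```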
@@ -0,0 +1,111 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+import mindspore.common.dtype as mstype
+from mindspore import Tensor, context
+from mindspore.nn import Cell
+from mindspore.ops import operations as P
+
+from parallel.utils.utils import ParallelValidator, compile_net
+
+
+class Net(Cell):
+    def __init__(self, strategy=None):
+        super(Net, self).__init__()
+        self.lin_space = P.LinSpace().shard(strategy)
+
+    def construct(self, start, end, x):
+        return self.lin_space(start, end, x)
+
+
+def test_lin_space_loss_auto_parallel():
+    """
+    Feature: test LinSpace auto parallel
+    Description: auto parallel
+    Expectation: compile success
+    """
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
+    start = Tensor(1, mstype.float32)
+    end = Tensor(10, mstype.float32)
+    x = 8
+    net = Net()
+    compile_net(net, start, end, x)
+
+
+def test_lin_space_parallel_with_repeated_cal():
+    """
+    Feature: test LinSpace parallel
+    Description: parallel with repeated_cal
+    Expectation: compile success
+    """
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    start = Tensor(1, mstype.float32)
+    end = Tensor(10, mstype.float32)
+    x = 8
+    strategy = ((4,),)
+    net = Net(strategy)
+    phase = compile_net(net, start, end, x)
+    validator = ParallelValidator(net, phase)
+    assert validator.check_node_inputs('LinSpace-0', ['Add-0', 'Add-1', 2])
+
+
+def test_lin_space_parallel_with_x_2():
+    """
+    Feature: test LinSpace parallel
+    Description: parallel with input x is 2
+    Expectation: compile success
+    """
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    start = Tensor(1, mstype.float32)
+    end = Tensor(10, mstype.float32)
+    x = 2
+    strategy = ((2,),)
+    net = Net(strategy)
+    phase = compile_net(net, start, end, x)
+    validator = ParallelValidator(net, phase)
+    assert validator.check_node_inputs('LinSpace-0', ['Add-0', 'Add-1', 1])
+
+
+def test_lin_space_data_parallel():
+    """
+    Feature: test LinSpace data parallel
+    Description: data parallel
+    Expectation: compile success
+    """
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    start = Tensor(1, mstype.float32)
+    end = Tensor(10, mstype.float32)
+    x = 8
+    net = Net()
+    phase = compile_net(net, start, end, x)
+    validator = ParallelValidator(net, phase)
+    assert validator.check_node_inputs('LinSpace-0', ['Add-0', 'Add-1', 1])
+
+
+def test_lin_space_parallel_strategy_wrong():
+    """
+    Feature: test LinSpace parallel with invalid strategy
+    Description: invalid strategy
+    Expectation: raise RuntimeError
+    """
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    start = Tensor(1, mstype.float32)
+    end = Tensor(10, mstype.float32)
+    x = 6
+    strategy = ((4,),)
+    net = Net(strategy)
+    with pytest.raises(RuntimeError):
+        compile_net(net, start, end, x)
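The third input the validator checks on LinSpace-0 is the per-device point count, and the start and end inputs are rewritten to per-slice values (the Add-0 and Add-1 nodes). The counts fall out of simple division, sketched below with the numbers from these tests:

```python
# Per-device point count implied by a 1-D LinSpace strategy.
def local_points(total_points: int, split: int) -> int:
    if total_points % split != 0:
        raise RuntimeError("points must divide evenly across the strategy")
    return total_points // split

assert local_points(8, 4) == 2   # test_lin_space_parallel_with_repeated_cal
assert local_points(2, 2) == 1   # test_lin_space_parallel_with_x_2
assert local_points(8, 8) == 1   # test_lin_space_data_parallel
# local_points(6, 4) raises, matching test_lin_space_parallel_strategy_wrong
```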
@@ -0,0 +1,118 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+
+import mindspore as ms
+import mindspore.context as context
+from mindspore import Tensor, Parameter
+import mindspore.nn as nn
+from mindspore.ops import operations as P
+
+from parallel.utils.utils import compile_net
+
+
+input_x_tensor_ = Tensor(np.ones([8, 8, 8]), ms.float32)
+input_x_parameter_ = Parameter(Tensor(np.ones([8, 8, 8]), ms.float32), "input_x")
+SIZE_SPLIT = [3, 3, 2]
+NUM_SPLIT = 3
+
+
+class Net(nn.Cell):
+    def __init__(self, size_split, split_dim, num_split, strategy1=None, strategy2=None):
+        super(Net, self).__init__()
+        self.splitv = P.SplitV(size_split, split_dim, num_split).shard(strategy1)
+        self.mul = P.Mul().shard(strategy2)
+        self.weight = Parameter(np.array([1.0]), name="mul_weight")
+
+    def construct(self, x):
+        out = self.splitv(x)
+        out = self.mul(out[0], self.weight)
+        return out
+
+
+class NetWithParameter(nn.Cell):
+    def __init__(self, size_split, split_dim, num_split, strategy1=None, strategy2=None):
+        super(NetWithParameter, self).__init__()
+        self.splitv = P.SplitV(size_split, split_dim, num_split).shard(strategy1)
+        self.mul = P.Mul().shard(strategy2)
+        self.weight = input_x_parameter_
+
+    def construct(self, x):
+        out = self.splitv(self.weight)
+        out = self.mul(x, out[0])
+        return out
+
+
+def test_splitv_auto_parallel():
+    """
+    Feature: test SplitV auto parallel
+    Description: auto parallel
+    Expectation: compile success
+    """
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
+    net = Net(SIZE_SPLIT, 0, NUM_SPLIT)
+    compile_net(net, input_x_tensor_)
+
+
+def test_splitv_auto_parallel_with_parameter():
+    """
+    Feature: test SplitV auto parallel with parameter input
+    Description: auto parallel with parameter input
+    Expectation: compile success
+    """
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
+    net = NetWithParameter(SIZE_SPLIT, 2, NUM_SPLIT)
+    x = Tensor(np.ones([8, 8, 3]), ms.float32)
+    compile_net(net, x)
+
+
+def test_splitv_data_parallel():
+    """
+    Feature: test SplitV data parallel
+    Description: data parallel
+    Expectation: compile success
+    """
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    net = Net(SIZE_SPLIT, 1, NUM_SPLIT)
+    compile_net(net, input_x_tensor_)
+
+
+def test_splitv_model_parallel():
+    """
+    Feature: test SplitV model parallel
+    Description: model parallel
+    Expectation: compile success
+    """
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((1, 2, 2),)
+    strategy2 = ((1, 2, 2), (1,))
+    net = Net(SIZE_SPLIT, 0, NUM_SPLIT, strategy1, strategy2)
+    compile_net(net, input_x_tensor_)
+
+
+def test_splitv_strategy_error():
+    """
+    Feature: test SplitV parallel with invalid strategy
+    Description: config invalid strategy
+    Expectation: raise RuntimeError
+    """
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((2, 2, 2),)
+    strategy2 = ((8, 1, 1),)
+    net = Net(SIZE_SPLIT, 0, NUM_SPLIT, strategy1, strategy2)
+    with pytest.raises(RuntimeError):
+        compile_net(net, input_x_tensor_)
+    context.reset_auto_parallel_context()
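The invalid strategy in test_splitv_strategy_error shards the split axis itself: with split_dim 0, every device needs the whole 8-element axis to carve out the [3, 3, 2] segments, so ((2, 2, 2),) is rejected while ((1, 2, 2),) compiles. A compact restatement of the apparent rule:

```python
# The split axis of SplitV cannot be sharded: each device needs the whole
# axis to carve out the size_split segments.
def splitv_strategy_ok(strategy, split_dim):
    return strategy[split_dim] == 1

assert splitv_strategy_ok((1, 2, 2), split_dim=0)      # model-parallel test
assert not splitv_strategy_ok((2, 2, 2), split_dim=0)  # strategy-error test
```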
@@ -0,0 +1,74 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mindspore import context
+from mindspore.nn import Cell
+from mindspore.ops import operations as P
+
+from parallel.utils.utils import ParallelValidator, compile_net
+
+SEED_ = 1
+SEED2_ = 1
+
+
+class Net(Cell):
+    def __init__(self, seed, seed2, strategy=None):
+        super(Net, self).__init__()
+        self.uniform_real = P.UniformReal(seed, seed2).shard(strategy)
+
+    def construct(self, shape):
+        out = self.uniform_real(shape)
+        return out
+
+
+def test_uniform_real_auto_parallel():
+    """
+    Features: test UniformReal auto parallel
+    Description: auto parallel
+    Expectation: compile success
+    """
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
+    net = Net(SEED_, SEED2_)
+    shape = (4, 4, 4)
+    compile_net(net, shape)
+
+
+def test_uniform_real_data_parallel():
+    """
+    Features: test UniformReal data parallel
+    Description: data parallel
+    Expectation: compile success
+    """
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=1)
+    net = Net(SEED_, SEED2_)
+    shape = (8, 8)
+    phase = compile_net(net, shape)
+
+    validator = ParallelValidator(net, phase)
+    assert validator.check_node_attrs("UniformReal-0", {"seed": 2, "seed2": 2})
+
+
+def test_uniform_real_model_parallel():
+    """
+    Features: test UniformReal model parallel
+    Description: model parallel
+    Expectation: compile success
+    """
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=5)
+    shape = (8, 8)
+    strategy = ((2, 2),)
+    net = Net(SEED_, SEED2_, strategy)
+    phase = compile_net(net, shape)
+    validator = ParallelValidator(net, phase)
+    assert validator.check_node_attrs("UniformReal-0", {"seed": 3, "seed2": 3})