forked from mindspore-Ecosystem/mindspore
!31252 Implementation of element wise parallel ops
Merge pull request !31252 from liuluobin/element_wise_ops
This commit is contained in:
commit
bf03f0e030
|
@ -191,6 +191,8 @@ using ResizeBilinearCost = CastCost;
|
|||
using BoundingBoxEncodeCost = CastCost;
|
||||
using IOUCost = CastCost;
|
||||
using RandomChoicWithMaskCost = CastCost;
|
||||
using IsFiniteCost = CastCost;
|
||||
using RintCost = CastCost;
|
||||
|
||||
class SqrtCost : public CastCost {
|
||||
public:
|
||||
|
@ -211,6 +213,11 @@ using AsinhCost = SqrtCost;
|
|||
using AcoshCost = SqrtCost;
|
||||
using ReLUV2Cost = SqrtCost;
|
||||
using TopKCost = SqrtCost;
|
||||
using HShrinkCost = SqrtCost;
|
||||
using HSigmoidCost = SqrtCost;
|
||||
using MishCost = SqrtCost;
|
||||
using SeLUCost = SqrtCost;
|
||||
using SoftShrinkCost = SqrtCost;
|
||||
|
||||
class ReLU6Cost : public CastCost {
|
||||
public:
|
||||
|
@ -240,6 +247,7 @@ using ErfCost = ReLU6Cost;
|
|||
using ErfcCost = ReLU6Cost;
|
||||
using ActivationInfoCost = ReLU6Cost;
|
||||
using SelectCost = ReLU6Cost;
|
||||
using XlogyCost = ReLU6Cost;
|
||||
|
||||
class TransposeCost : public CastCost {
|
||||
public:
|
||||
|
@ -289,6 +297,9 @@ class SoftmaxCost : public OperatorCost {
|
|||
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
|
||||
};
|
||||
|
||||
using CumSumCost = SoftmaxCost;
|
||||
using CumProdCost = SoftmaxCost;
|
||||
|
||||
class TileCost : public SoftmaxCost {
|
||||
public:
|
||||
TileCost() : SoftmaxCost() {}
|
||||
|
@ -619,6 +630,11 @@ using GreaterEqualCost = SubCost;
|
|||
using LessCost = SubCost;
|
||||
using LessEqualCost = SubCost;
|
||||
using GatherNdCost = SubCost;
|
||||
using BitwiseAndCost = SubCost;
|
||||
using BitwiseOrCost = SubCost;
|
||||
using BitwiseXorCost = SubCost;
|
||||
using AddNCost = SubCost;
|
||||
using InplaceAddCost = SubCost;
|
||||
|
||||
class MulCost : public SubCost {
|
||||
public:
|
||||
|
@ -628,7 +644,9 @@ class MulCost : public SubCost {
|
|||
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
|
||||
};
|
||||
|
||||
using MulNoNanCost = MulCost;
|
||||
using GatherDCost = MulCost;
|
||||
using LerpCost = MulCost;
|
||||
|
||||
class DivCost : public SubCost {
|
||||
public:
|
||||
|
@ -640,6 +658,9 @@ class DivCost : public SubCost {
|
|||
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
|
||||
};
|
||||
using ReadDivCost = DivCost;
|
||||
using TruncateDivCost = DivCost;
|
||||
using XdivyCost = DivCost;
|
||||
using CdistCost = DivCost;
|
||||
|
||||
class ModCost : public SubCost {
|
||||
public:
|
||||
|
@ -649,6 +670,7 @@ class ModCost : public SubCost {
|
|||
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
|
||||
};
|
||||
using FloorModCost = ModCost;
|
||||
using TruncateModCost = ModCost;
|
||||
|
||||
class PowCost : public SubCost {
|
||||
public:
|
||||
|
@ -702,6 +724,7 @@ class MaximumCost : public SubCost {
|
|||
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
|
||||
};
|
||||
using MinimumCost = MaximumCost;
|
||||
using CumminCost = MaximumCost;
|
||||
|
||||
class SliceCost : public CastCost {
|
||||
public:
|
||||
|
@ -754,6 +777,7 @@ class ReduceSumCost : public OperatorCost {
|
|||
using ReduceMethodCost = ReduceSumCost;
|
||||
using ReduceProdCost = ReduceSumCost;
|
||||
using SquareSumAllCost = ReduceSumCost;
|
||||
using L2LossCost = ReduceSumCost;
|
||||
|
||||
class ReduceMeanCost : public ReduceSumCost {
|
||||
public:
|
||||
|
|
|
@ -221,6 +221,28 @@ REGISTER(ArgmaxInfo);
|
|||
REGISTER(ArgminInfo);
|
||||
REGISTER(UnsortedSegmentProdInfo);
|
||||
REGISTER(SquareSumAllInfo);
|
||||
REGISTER(AddNInfo);
|
||||
REGISTER(BitwiseAndInfo);
|
||||
REGISTER(BitwiseOrInfo);
|
||||
REGISTER(BitwiseXorInfo);
|
||||
REGISTER(CumProdInfo);
|
||||
REGISTER(HShrinkInfo);
|
||||
REGISTER(HSigmoidInfo);
|
||||
REGISTER(IsFiniteInfo);
|
||||
REGISTER(MishInfo);
|
||||
REGISTER(MulNoNanInfo);
|
||||
REGISTER(RintInfo);
|
||||
REGISTER(SeLUInfo);
|
||||
REGISTER(SoftShrinkInfo);
|
||||
REGISTER(TruncateDivInfo);
|
||||
REGISTER(TruncateModInfo);
|
||||
REGISTER(XdivyInfo);
|
||||
REGISTER(XlogyInfo);
|
||||
REGISTER(InplaceAddInfo);
|
||||
REGISTER(InplaceSubInfo);
|
||||
REGISTER(CdistInfo);
|
||||
REGISTER(L2LossInfo);
|
||||
REGISTER(LerpInfo);
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "frontend/parallel/auto_parallel/costmodel.h"
|
||||
#include "frontend/parallel/device_matrix.h"
|
||||
#include "frontend/parallel/strategy.h"
|
||||
#include "frontend/parallel/tensor_layout/redistribution_operator_infer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
|
@ -204,8 +205,8 @@ std::vector<StrategyPtr> Softmax::GenerateOpStrategies(int64_t stage_id) {
|
|||
return sp_vector;
|
||||
}
|
||||
|
||||
Status CumSumInfo::GetAttrs() {
|
||||
if (input_value_.size() != CUMSUM_INPUT_SIZE) {
|
||||
Status CumOpBase::GetAttrs() {
|
||||
if (input_value_.size() != CUM_OP_INPUT_SIZE) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid inputs size " << input_value_.size();
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -237,7 +238,7 @@ Status CumSumInfo::GetAttrs() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status CumSumInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
Status CumOpBase::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy.";
|
||||
return FAILED;
|
||||
|
@ -259,7 +260,7 @@ Status CumSumInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
std::vector<StrategyPtr> CumSumInfo::GenerateOpStrategies(int64_t stage_id) {
|
||||
std::vector<StrategyPtr> CumOpBase::GenerateOpStrategies(int64_t stage_id) {
|
||||
Shape input0_split(inputs_shape_[0].size(), 1);
|
||||
if (axis_ < 0 || LongToSize(axis_) >= inputs_shape_[0].size()) {
|
||||
MS_LOG(EXCEPTION) << "Wrong axis value: " << axis_;
|
||||
|
@ -275,31 +276,22 @@ std::vector<StrategyPtr> CumSumInfo::GenerateOpStrategies(int64_t stage_id) {
|
|||
return sp_vector;
|
||||
}
|
||||
|
||||
Status CumSumInfo::InferMirrorOps() {
|
||||
mirror_ops_.clear();
|
||||
Shape input_a_tensor_map = inputs_tensor_map_.at(0);
|
||||
std::vector<Group> input_a_group;
|
||||
if (CreateGroupByTensorMap(input_a_tensor_map, &input_a_group) != SUCCESS) {
|
||||
ReportError(name_ + ": Create group for input a failed.");
|
||||
void CumOpBase::ReComputeBatchSplitFlagList() { axis_ == 0 ? split_flag_list_[0] = false : split_flag_list_[0] = true; }
|
||||
|
||||
Status CumOpBase::InferMirrorOps() {
|
||||
if (OperatorInfo::InferMirrorOps() != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
OperatorVector op_for_input_a, op_for_axis;
|
||||
if (input_a_group.empty()) {
|
||||
MS_LOG(INFO) << name_ << ": The mirror group is empty.";
|
||||
// No need to insert mirror ops
|
||||
if (mirror_ops_.empty()) {
|
||||
return SUCCESS;
|
||||
} else {
|
||||
op_for_input_a = CreateMirrorOps(input_a_group[0].name(), input_a_group[0].GetDevNum());
|
||||
MS_LOG(INFO) << name_ << ": Create the mirror ops for input a success, groups is " << input_a_group[0].name();
|
||||
}
|
||||
|
||||
mirror_ops_.push_back(op_for_input_a);
|
||||
mirror_ops_.push_back(op_for_axis);
|
||||
|
||||
OperatorVector op_for_axis;
|
||||
(void)mirror_ops_.emplace_back(std::move(op_for_axis));
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status CumSumInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
||||
|
||||
Status ActivationBase::InferDevMatrixShape() {
|
||||
Strategys stra = strategy_->GetInputDim();
|
||||
Dimensions input_strategy = stra.at(0);
|
||||
|
@ -613,5 +605,14 @@ Status SqueezeInfo::InferTensorMap() {
|
|||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status L2LossInfo::InferTensorMap() {
|
||||
if (ActivationOther::InferTensorMap() != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
// outputs_shape is [], so clearing its tensor map.
|
||||
outputs_tensor_map_[0].clear();
|
||||
return SUCCESS;
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -119,24 +119,6 @@ class Softmax : public ActivationBase {
|
|||
std::vector<int64_t> axis_;
|
||||
};
|
||||
|
||||
class CumSumInfo : public ActivationBase {
|
||||
public:
|
||||
explicit CumSumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCost>()) {}
|
||||
~CumSumInfo() override = default;
|
||||
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
|
||||
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
|
||||
|
||||
protected:
|
||||
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||
Status InferMirrorOps() override;
|
||||
Status GetAttrs() override;
|
||||
|
||||
private:
|
||||
int64_t axis_ = -1;
|
||||
};
|
||||
|
||||
class SoftmaxInfo : public Softmax {
|
||||
public:
|
||||
SoftmaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
|
@ -153,6 +135,40 @@ class LogSoftmaxInfo : public Softmax {
|
|||
~LogSoftmaxInfo() override = default;
|
||||
};
|
||||
|
||||
class CumOpBase : public ActivationBase {
|
||||
public:
|
||||
CumOpBase(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs, OperatorCostPtr cost)
|
||||
: ActivationBase(name, inputs_shape, outputs_shape, attrs, cost) {}
|
||||
~CumOpBase() override = default;
|
||||
|
||||
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
|
||||
Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); }
|
||||
void ReComputeBatchSplitFlagList() override;
|
||||
|
||||
protected:
|
||||
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||
Status InferMirrorOps() override;
|
||||
Status GetAttrs() override;
|
||||
int axis_ = -1;
|
||||
};
|
||||
|
||||
class CumSumInfo : public CumOpBase {
|
||||
public:
|
||||
CumSumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: CumOpBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<CumSumCost>()) {}
|
||||
~CumSumInfo() override = default;
|
||||
};
|
||||
|
||||
class CumProdInfo : public CumOpBase {
|
||||
public:
|
||||
CumProdInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: CumOpBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<CumProdCost>()) {}
|
||||
~CumProdInfo() = default;
|
||||
};
|
||||
|
||||
class EluInfo : public ActivationOther {
|
||||
public:
|
||||
EluInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
|
||||
|
@ -299,6 +315,73 @@ class DropoutInfo : public ActivationOther {
|
|||
return ++SEED_NUM;
|
||||
}
|
||||
};
|
||||
|
||||
class HShrinkInfo : public ActivationOther {
|
||||
public:
|
||||
HShrinkInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<HShrinkCost>()) {}
|
||||
~HShrinkInfo() = default;
|
||||
};
|
||||
|
||||
class HSigmoidInfo : public ActivationOther {
|
||||
public:
|
||||
HSigmoidInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<HSigmoidCost>()) {}
|
||||
~HSigmoidInfo() = default;
|
||||
};
|
||||
|
||||
class IsFiniteInfo : public ActivationOther {
|
||||
public:
|
||||
IsFiniteInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<IsFiniteCost>()) {}
|
||||
~IsFiniteInfo() = default;
|
||||
};
|
||||
|
||||
class MishInfo : public ActivationOther {
|
||||
public:
|
||||
MishInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<MishCost>()) {}
|
||||
~MishInfo() = default;
|
||||
};
|
||||
|
||||
class RintInfo : public ActivationOther {
|
||||
public:
|
||||
RintInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<RintCost>()) {}
|
||||
~RintInfo() = default;
|
||||
};
|
||||
|
||||
class SeLUInfo : public ActivationOther {
|
||||
public:
|
||||
SeLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SeLUCost>()) {}
|
||||
~SeLUInfo() = default;
|
||||
};
|
||||
|
||||
class SoftShrinkInfo : public ActivationOther {
|
||||
public:
|
||||
SoftShrinkInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftShrinkCost>()) {}
|
||||
~SoftShrinkInfo() override = default;
|
||||
};
|
||||
|
||||
class L2LossInfo : public ActivationOther {
|
||||
public:
|
||||
L2LossInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &output_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ActivationOther(name, inputs_shape, output_shape, attrs, std::make_shared<L2LossCost>()) {}
|
||||
~L2LossInfo() = default;
|
||||
|
||||
protected:
|
||||
Status InferTensorMap() override;
|
||||
};
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <utility>
|
||||
#include <algorithm>
|
||||
|
||||
#include "frontend/parallel/ops_info/addn_info.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
Status AddNInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// The strategy for each input tensor must be equal
|
||||
Strategys strategies = strategy->GetInputDim();
|
||||
for (size_t i = 1; i < strategies.size(); ++i) {
|
||||
if (strategies[i] != strategies[0]) {
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status AddNInfo::InferDevMatrixShape() {
|
||||
dev_matrix_shape_.clear();
|
||||
|
||||
Strategys strategies = strategy_->GetInputDim();
|
||||
if (strategies.empty()) {
|
||||
return SUCCESS;
|
||||
}
|
||||
dev_matrix_shape_.assign(strategies[0].begin(), strategies[0].end());
|
||||
|
||||
MS_LOG(INFO) << name_ << ": dev matrix: " << ShapeToString(dev_matrix_shape_);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status AddNInfo::InferTensorMap() {
|
||||
inputs_tensor_map_.clear();
|
||||
outputs_tensor_map_.clear();
|
||||
|
||||
Shape sub_tensor_map;
|
||||
size_t dev_size = dev_matrix_shape_.size();
|
||||
for (size_t i = 0; i < dev_size; ++i) {
|
||||
sub_tensor_map.push_back(dev_size - i - 1);
|
||||
}
|
||||
|
||||
Strategys strategies = strategy_->GetInputDim();
|
||||
for (size_t i = 0; i < strategies.size(); ++i) {
|
||||
inputs_tensor_map_.push_back(sub_tensor_map);
|
||||
}
|
||||
(void)outputs_tensor_map_.emplace_back(std::move(sub_tensor_map));
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
std::vector<StrategyPtr> AddNInfo::GenerateOpStrategies(int64_t stage_id) {
|
||||
Shapes splittable_inputs;
|
||||
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
|
||||
(void)splittable_inputs.emplace_back(inputs_shape_[i].size());
|
||||
for (size_t j = 0; j < inputs_shape_[i].size(); ++j) {
|
||||
splittable_inputs[i][j] = j + 1;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<StrategyPtr> sp_vector;
|
||||
if (GenerateStrategiesForDependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": Generate strategies for dependent inputs() failed.";
|
||||
}
|
||||
|
||||
return sp_vector;
|
||||
}
|
||||
|
||||
void AddNInfo::ReComputeBatchSplitFlagList() {
|
||||
bool flag = false;
|
||||
if (!inputs_shape_[0].empty()) {
|
||||
flag = true;
|
||||
}
|
||||
|
||||
// Batch dim of each input can be split
|
||||
for (size_t i = 0; i < split_flag_list_.size(); ++i) {
|
||||
split_flag_list_[i] = flag;
|
||||
}
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,51 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ADDN_INFO_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ADDN_INFO_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "utils/hash_map.h"
|
||||
#include "ir/value.h"
|
||||
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
|
||||
#include "frontend/parallel/ops_info/operator_info.h"
|
||||
#include "frontend/parallel/strategy.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
class AddNInfo : public OperatorInfo {
|
||||
public:
|
||||
AddNInfo(const std::string &name, const Shapes &input_shape, const Shapes &output_shape, const PrimitiveAttrs &attrs)
|
||||
: OperatorInfo(name, input_shape, output_shape, attrs, std::make_shared<AddNCost>()) {}
|
||||
~AddNInfo() = default;
|
||||
|
||||
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
|
||||
Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); }
|
||||
void ReComputeBatchSplitFlagList() override;
|
||||
|
||||
protected:
|
||||
Status GetAttrs() override { return SUCCESS; }
|
||||
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||
Status InferForwardCommunication() override { return SUCCESS; }
|
||||
Status InferDevMatrixShape() override;
|
||||
Status InferTensorMap() override;
|
||||
};
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ADDN_INFO_H_
|
|
@ -26,7 +26,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
Shape ExpendShape(const Shape &bigger_size_shape, Shape smaller_size_shape) {
|
||||
Shape ExpandShape(const Shape &bigger_size_shape, Shape smaller_size_shape) {
|
||||
size_t insert_num = bigger_size_shape.size() - smaller_size_shape.size();
|
||||
for (size_t num = 0; num < insert_num; ++num) {
|
||||
(void)smaller_size_shape.insert(smaller_size_shape.begin(), 1);
|
||||
|
@ -34,7 +34,7 @@ Shape ExpendShape(const Shape &bigger_size_shape, Shape smaller_size_shape) {
|
|||
return smaller_size_shape;
|
||||
}
|
||||
|
||||
Shapes ArithmeticBase::InferExpendShape() {
|
||||
Shapes ArithmeticBase::InferExpandShape() {
|
||||
Shape input_a_shape = inputs_shape_.at(0);
|
||||
Shape input_b_shape = inputs_shape_.at(1);
|
||||
Shapes input_shapes;
|
||||
|
@ -42,9 +42,9 @@ Shapes ArithmeticBase::InferExpendShape() {
|
|||
size_t input_b_size = input_b_shape.size();
|
||||
if (input_a_size > input_b_size) {
|
||||
input_shapes.push_back(input_a_shape);
|
||||
input_shapes.push_back(ExpendShape(input_a_shape, input_b_shape));
|
||||
input_shapes.push_back(ExpandShape(input_a_shape, input_b_shape));
|
||||
} else if (input_a_size < input_b_size) {
|
||||
input_shapes.push_back(ExpendShape(input_b_shape, input_a_shape));
|
||||
input_shapes.push_back(ExpandShape(input_b_shape, input_a_shape));
|
||||
input_shapes.push_back(input_b_shape);
|
||||
} else {
|
||||
input_shapes.push_back(input_a_shape);
|
||||
|
@ -53,23 +53,23 @@ Shapes ArithmeticBase::InferExpendShape() {
|
|||
return input_shapes;
|
||||
}
|
||||
|
||||
Strategys ExpendStrategy(const StrategyPtr &strategy) {
|
||||
Strategys expend_strategy;
|
||||
Strategys ExpandStrategy(const StrategyPtr &strategy) {
|
||||
Strategys expand_strategy;
|
||||
Strategys stra = strategy->GetInputDim();
|
||||
Dimensions sub_a_strategy = stra.at(0);
|
||||
Dimensions sub_b_strategy = stra.at(1);
|
||||
size_t input_a_size = sub_a_strategy.size();
|
||||
size_t input_b_size = sub_b_strategy.size();
|
||||
if (input_a_size > input_b_size) {
|
||||
expend_strategy.push_back(sub_a_strategy);
|
||||
expend_strategy.push_back(ExpendShape(sub_a_strategy, sub_b_strategy));
|
||||
expand_strategy.push_back(sub_a_strategy);
|
||||
expand_strategy.push_back(ExpandShape(sub_a_strategy, sub_b_strategy));
|
||||
} else if (input_a_size < input_b_size) {
|
||||
expend_strategy.push_back(ExpendShape(sub_b_strategy, sub_a_strategy));
|
||||
expend_strategy.push_back(sub_b_strategy);
|
||||
expand_strategy.push_back(ExpandShape(sub_b_strategy, sub_a_strategy));
|
||||
expand_strategy.push_back(sub_b_strategy);
|
||||
} else {
|
||||
expend_strategy = stra;
|
||||
expand_strategy = stra;
|
||||
}
|
||||
return expend_strategy;
|
||||
return expand_strategy;
|
||||
}
|
||||
|
||||
Status ArithmeticBase::CheckStrategy(const StrategyPtr &strategy) {
|
||||
|
@ -77,10 +77,10 @@ Status ArithmeticBase::CheckStrategy(const StrategyPtr &strategy) {
|
|||
MS_LOG(ERROR) << name_ << " : Invalid strategy.";
|
||||
return FAILED;
|
||||
}
|
||||
Shapes input_shapes = InferExpendShape();
|
||||
Strategys expend_strategy = ExpendStrategy(strategy);
|
||||
Dimensions sub_a_strategy = expend_strategy.at(0);
|
||||
Dimensions sub_b_strategy = expend_strategy.at(1);
|
||||
Shapes input_shapes = InferExpandShape();
|
||||
Strategys expand_strategy = ExpandStrategy(strategy);
|
||||
Dimensions sub_a_strategy = expand_strategy.at(0);
|
||||
Dimensions sub_b_strategy = expand_strategy.at(1);
|
||||
Shape input_a_shape = input_shapes.at(0);
|
||||
Shape input_b_shape = input_shapes.at(1);
|
||||
|
||||
|
@ -94,9 +94,9 @@ Status ArithmeticBase::CheckStrategy(const StrategyPtr &strategy) {
|
|||
}
|
||||
|
||||
Status ArithmeticBase::InferDevMatrixShape() {
|
||||
Strategys expend_strategy = ExpendStrategy(strategy_);
|
||||
Dimensions sub_a_strategy = expend_strategy.at(0);
|
||||
Dimensions sub_b_strategy = expend_strategy.at(1);
|
||||
Strategys expand_strategy = ExpandStrategy(strategy_);
|
||||
Dimensions sub_a_strategy = expand_strategy.at(0);
|
||||
Dimensions sub_b_strategy = expand_strategy.at(1);
|
||||
Shape dev_shape;
|
||||
for (size_t i = 0; i < sub_a_strategy.size(); ++i) {
|
||||
if (sub_a_strategy[i] != sub_b_strategy[i]) {
|
||||
|
@ -110,7 +110,7 @@ Status ArithmeticBase::InferDevMatrixShape() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
TensorMap SetExpendTensorMap(const Shape &strategy, const Shape &dev_matrix_shape) {
|
||||
TensorMap SetExpandTensorMap(const Shape &strategy, const Shape &dev_matrix_shape) {
|
||||
TensorMap tensor_map_index;
|
||||
for (size_t i = 0; i < strategy.size(); ++i) {
|
||||
if (strategy[i] == dev_matrix_shape[i]) {
|
||||
|
@ -122,56 +122,58 @@ TensorMap SetExpendTensorMap(const Shape &strategy, const Shape &dev_matrix_shap
|
|||
return tensor_map_index;
|
||||
}
|
||||
|
||||
TensorMap SetTensorMap(const Shape &strategy_expend, const Shape &dev_matrix_shape, const Shape &strategy) {
|
||||
TensorMap expend_map = SetExpendTensorMap(strategy_expend, dev_matrix_shape);
|
||||
TensorMap SetTensorMap(const Shape &strategy_expand, const Shape &dev_matrix_shape, const Shape &strategy) {
|
||||
TensorMap expand_map = SetExpandTensorMap(strategy_expand, dev_matrix_shape);
|
||||
size_t dev_matrix_size = dev_matrix_shape.size();
|
||||
size_t strategy_size = strategy.size();
|
||||
if (dev_matrix_size != strategy_size) {
|
||||
(void)expend_map.erase(expend_map.begin(),
|
||||
expend_map.begin() + static_cast<different_type>(dev_matrix_size - strategy_size));
|
||||
(void)expand_map.erase(expand_map.begin(),
|
||||
expand_map.begin() + static_cast<different_type>(dev_matrix_size - strategy_size));
|
||||
}
|
||||
return expend_map;
|
||||
return expand_map;
|
||||
}
|
||||
|
||||
void ArithmeticBase::ReComputeBatchSplitFlagList() {
|
||||
Shapes expend_shapes = InferExpendShape();
|
||||
Shape expend_a_shape = expend_shapes.at(0);
|
||||
Shape expend_b_shape = expend_shapes.at(1);
|
||||
if (expend_a_shape.size() != expend_b_shape.size()) {
|
||||
Shapes expand_shapes = InferExpandShape();
|
||||
Shape expand_a_shape = expand_shapes.at(0);
|
||||
Shape expand_b_shape = expand_shapes.at(1);
|
||||
if (expand_a_shape.size() != expand_b_shape.size()) {
|
||||
MS_LOG(EXCEPTION) << name_ << " : Recompute batch split flag list is wrong.";
|
||||
}
|
||||
if (expend_a_shape.empty()) {
|
||||
if (expand_a_shape.empty()) {
|
||||
split_flag_list_[0] = false;
|
||||
split_flag_list_[1] = false;
|
||||
return;
|
||||
}
|
||||
(expend_a_shape.at(0) != 1) ? (split_flag_list_[0] = true) : (split_flag_list_[0] = false);
|
||||
(expend_b_shape.at(0) != 1) ? (split_flag_list_[1] = true) : (split_flag_list_[1] = false);
|
||||
(expand_a_shape.at(0) != 1) ? (split_flag_list_[0] = true) : (split_flag_list_[0] = false);
|
||||
(expand_b_shape.at(0) != 1) ? (split_flag_list_[1] = true) : (split_flag_list_[1] = false);
|
||||
}
|
||||
|
||||
Status ArithmeticBase::InferTensorMap() {
|
||||
Shape tensor_map_index;
|
||||
Strategys expend_strategy = ExpendStrategy(strategy_);
|
||||
Dimensions sub_a_expend_strategy = expend_strategy.at(0);
|
||||
Dimensions sub_b_expend_strategy = expend_strategy.at(1);
|
||||
Strategys expand_strategy = ExpandStrategy(strategy_);
|
||||
Dimensions sub_a_expand_strategy = expand_strategy.at(0);
|
||||
Dimensions sub_b_expand_strategy = expand_strategy.at(1);
|
||||
Strategys stra = strategy_->GetInputDim();
|
||||
Dimensions sub_a_strategy = stra.at(0);
|
||||
Dimensions sub_b_strategy = stra.at(1);
|
||||
for (size_t i = 0; i < sub_a_expend_strategy.size(); ++i) {
|
||||
tensor_map_index.push_back((int64_t)(LAST_INDEX(sub_a_expend_strategy.size()) - i));
|
||||
for (size_t i = 0; i < sub_a_expand_strategy.size(); ++i) {
|
||||
tensor_map_index.push_back((int64_t)(LAST_INDEX(sub_a_expand_strategy.size()) - i));
|
||||
}
|
||||
|
||||
Shape dev_shape;
|
||||
for (size_t i = 0; i < sub_a_expend_strategy.size(); ++i) {
|
||||
if (sub_a_expend_strategy[i] != sub_b_expend_strategy[i]) {
|
||||
dev_shape.push_back(sub_a_expend_strategy[i] * sub_b_expend_strategy[i]);
|
||||
// Get dev matrix without repeated calculation
|
||||
Shape dev_shape = dev_matrix_shape_;
|
||||
if (repeated_calc_num_ > 1) {
|
||||
if (repeated_num_in_dev_matrix_right_) {
|
||||
dev_shape.pop_back();
|
||||
} else {
|
||||
dev_shape.push_back(sub_a_expend_strategy[i]);
|
||||
dev_shape.erase(dev_shape.begin());
|
||||
}
|
||||
}
|
||||
inputs_tensor_map_.push_back(SetTensorMap(sub_a_expend_strategy, dev_shape, sub_a_strategy));
|
||||
inputs_tensor_map_.push_back(SetTensorMap(sub_b_expend_strategy, dev_shape, sub_b_strategy));
|
||||
outputs_tensor_map_.push_back(tensor_map_index);
|
||||
|
||||
(void)inputs_tensor_map_.emplace_back(SetTensorMap(sub_a_expand_strategy, dev_shape, sub_a_strategy));
|
||||
(void)inputs_tensor_map_.emplace_back(SetTensorMap(sub_b_expand_strategy, dev_shape, sub_b_strategy));
|
||||
(void)outputs_tensor_map_.emplace_back(std::move(tensor_map_index));
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
@ -182,14 +184,172 @@ std::vector<StrategyPtr> ArithmeticBase::GenerateOpStrategies(int64_t stage_id)
|
|||
Shape input0_split(inputs_shape_[0].size(), 1);
|
||||
Shape input1_split(inputs_shape_[1].size(), 1);
|
||||
Shapes splittable_inputs = {input0_split, input1_split};
|
||||
if (inputs_shape_.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": Size of inputs must be greater than or equal to 2, but got size "
|
||||
<< inputs_shape_.size();
|
||||
}
|
||||
Shapes inputs_shape(inputs_shape_.begin(), inputs_shape_.begin() + 2);
|
||||
|
||||
std::vector<StrategyPtr> sp_vector;
|
||||
if (GenerateStrategiesWithBroadcast(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
|
||||
if (GenerateStrategiesWithBroadcast(stage_id, inputs_shape, splittable_inputs, &sp_vector) != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << name_ << " : Generate strategies with broadcast failed.";
|
||||
}
|
||||
MS_LOG(INFO) << name_ << " : Generate strategies with broadcast success.";
|
||||
|
||||
return sp_vector;
|
||||
}
|
||||
|
||||
Status LerpInfo::GetAttrs() {
|
||||
inputs_size_ = inputs_shape_.size();
|
||||
if (inputs_size_ != 2 && inputs_size_ != 3) {
|
||||
MS_LOG(ERROR) << name_ << ": Inputs size must be 2 or 3, but got size " << inputs_size_;
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status LerpInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (inputs_size_ == 2) {
|
||||
return ArithmeticBase::CheckStrategy(strategy);
|
||||
}
|
||||
|
||||
// validate strategy between 'start' and 'end'
|
||||
if (ArithmeticBase::CheckStrategy(strategy) != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// validate strategy of weight
|
||||
Strategys expand_strategy = ExpandStrategy(strategy);
|
||||
Dimensions expand_begin_strategy = expand_strategy.at(0);
|
||||
Dimensions expand_end_strategy = expand_strategy.at(1);
|
||||
Dimensions expand_cmp_strategy;
|
||||
for (size_t i = 0; i < expand_begin_strategy.size(); ++i) {
|
||||
expand_cmp_strategy.push_back(std::max(expand_begin_strategy[i], expand_end_strategy[i]));
|
||||
}
|
||||
Dimensions expand_weight_strategy = ExpandShape(expand_cmp_strategy, strategy->GetInputDim().at(2));
|
||||
|
||||
Shapes input_shapes = InferExpandShape();
|
||||
Shape expand_begin_shape = input_shapes.at(0);
|
||||
Shape expand_end_shape = input_shapes.at(1);
|
||||
Shape expand_cmp_shape;
|
||||
for (size_t i = 0; i < expand_begin_shape.size(); ++i) {
|
||||
expand_cmp_shape.push_back(std::max(expand_begin_shape[i], expand_end_shape[i]));
|
||||
}
|
||||
Shape expand_weight_shape = ExpandShape(expand_cmp_shape, inputs_shape_[2]);
|
||||
|
||||
for (size_t i = 0; i < expand_cmp_shape.size(); ++i) {
|
||||
if ((expand_cmp_strategy[i] != expand_weight_strategy[i]) && (expand_cmp_shape[i] != 1) &&
|
||||
(expand_weight_shape[i] != 1)) {
|
||||
MS_LOG(ERROR) << name_ << " : Invalid strategy.";
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status LerpInfo::InferDevMatrixShape() {
|
||||
if (inputs_size_ == 2) {
|
||||
return ArithmeticBase::InferDevMatrixShape();
|
||||
}
|
||||
|
||||
dev_matrix_shape_.clear();
|
||||
Strategys expand_strategy = ExpandStrategy(strategy_);
|
||||
Dimensions expand_start_strategy = expand_strategy.at(0);
|
||||
Dimensions expand_end_strategy = expand_strategy.at(1);
|
||||
Dimensions expand_weight_strategy = ExpandShape(expand_start_strategy, strategy_->GetInputDim().at(2));
|
||||
for (size_t i = 0; i < expand_start_strategy.size(); ++i) {
|
||||
if (expand_start_strategy[i] == expand_end_strategy[i] && expand_start_strategy[i] == expand_weight_strategy[i]) {
|
||||
dev_matrix_shape_.push_back(expand_start_strategy[i]);
|
||||
} else {
|
||||
dev_matrix_shape_.push_back(
|
||||
std::max(std::max(expand_start_strategy[i], expand_end_strategy[i]), expand_weight_strategy[i]));
|
||||
}
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << name_ << ": The dev matrix is " << ShapeToString(dev_matrix_shape_);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status LerpInfo::InferTensorMap() {
|
||||
if (inputs_size_ == 2) {
|
||||
return ArithmeticBase::InferTensorMap();
|
||||
}
|
||||
|
||||
inputs_tensor_map_.clear();
|
||||
outputs_tensor_map_.clear();
|
||||
// Generate inputs tensor map for 'start' and end, outputs tensor map
|
||||
if (ArithmeticBase::InferTensorMap() != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
// Generate tensor map for 'weight'
|
||||
Strategys stra = strategy_->GetInputDim();
|
||||
Dimensions weight_strategy = stra.at(2);
|
||||
Strategys expand_strategy = ExpandStrategy(strategy_);
|
||||
Dimensions expand_start_strategy = expand_strategy.at(0);
|
||||
Dimensions expand_weight_strategy = ExpandShape(expand_start_strategy, weight_strategy);
|
||||
Shape dev_shape = dev_matrix_shape_;
|
||||
if (repeated_calc_num_ > 1) {
|
||||
if (repeated_num_in_dev_matrix_right_) {
|
||||
dev_shape.pop_back();
|
||||
} else {
|
||||
dev_shape.erase(dev_shape.begin());
|
||||
}
|
||||
}
|
||||
inputs_tensor_map_.push_back(SetTensorMap(expand_weight_strategy, dev_shape, weight_strategy));
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
std::vector<StrategyPtr> LerpInfo::GenerateOpStrategies(int64_t stage_id) {
|
||||
if (inputs_size_ == 2) {
|
||||
return ArithmeticBase::GenerateOpStrategies(stage_id);
|
||||
}
|
||||
|
||||
// search strategy for 'start' and 'end'
|
||||
auto sub_sp_vector = ArithmeticBase::GenerateOpStrategies(stage_id);
|
||||
|
||||
// infer strategy for 'weight' according to strategy of 'start' and 'end'
|
||||
std::vector<StrategyPtr> sp_vector;
|
||||
for (const auto &sub_sp : sub_sp_vector) {
|
||||
auto expand_sub_strategies = ExpandStrategy(sub_sp);
|
||||
auto expand_start_strategy = expand_sub_strategies.at(0);
|
||||
auto expand_end_strategy = expand_sub_strategies.at(1);
|
||||
Dimensions expand_cmp_strategy;
|
||||
for (size_t i = 0; i < expand_start_strategy.size(); ++i) {
|
||||
expand_cmp_strategy.push_back(std::max(expand_start_strategy[i], expand_end_strategy[i]));
|
||||
}
|
||||
auto weight_shape = inputs_shape_.at(2);
|
||||
size_t offset = expand_cmp_strategy.size() - weight_shape.size();
|
||||
Dimensions weight_strategy;
|
||||
for (size_t i = 0; i < weight_shape.size(); ++i) {
|
||||
if (weight_shape[i] == 1) {
|
||||
weight_strategy.push_back(1);
|
||||
} else {
|
||||
weight_strategy.push_back(expand_cmp_strategy[offset + i]);
|
||||
}
|
||||
}
|
||||
auto strategies = sub_sp->GetInputDim();
|
||||
(void)strategies.emplace_back(weight_strategy);
|
||||
(void)sp_vector.emplace_back(std::make_shared<Strategy>(stage_id, strategies));
|
||||
}
|
||||
|
||||
return sp_vector;
|
||||
}
|
||||
|
||||
void LerpInfo::ReComputeBatchSplitFlagList() {
|
||||
// Set split flag for 'start' and 'end'
|
||||
ArithmeticBase::ReComputeBatchSplitFlagList();
|
||||
|
||||
// if 'weight' is float, return
|
||||
if (inputs_shape_.size() == 2) {
|
||||
return;
|
||||
}
|
||||
|
||||
// set split flag for 'weight'
|
||||
Shapes expand_shapes = InferExpandShape();
|
||||
Shape expand_a_shape = expand_shapes.at(0);
|
||||
Shape expand_weight_shape = ExpandShape(expand_a_shape, inputs_shape_.at(2));
|
||||
(expand_weight_shape.at(0) != 1) ? (split_flag_list_[2] = true) : (split_flag_list_[2] = false);
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -45,7 +45,7 @@ class ArithmeticBase : public OperatorInfo {
|
|||
Status InferForwardCommunication() override { return SUCCESS; }
|
||||
Status InferDevMatrixShape() override;
|
||||
Status InferTensorMap() override;
|
||||
Shapes InferExpendShape();
|
||||
Shapes InferExpandShape();
|
||||
};
|
||||
|
||||
class SubInfo : public ArithmeticBase {
|
||||
|
@ -179,6 +179,90 @@ class LogicalOrInfo : public ArithmeticBase {
|
|||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<LogicalOrCost>()) {}
|
||||
~LogicalOrInfo() override = default;
|
||||
};
|
||||
|
||||
class BitwiseAndInfo : public ArithmeticBase {
|
||||
public:
|
||||
BitwiseAndInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<BitwiseAndCost>()) {}
|
||||
~BitwiseAndInfo() override = default;
|
||||
};
|
||||
|
||||
class BitwiseOrInfo : public ArithmeticBase {
|
||||
public:
|
||||
BitwiseOrInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<BitwiseOrCost>()) {}
|
||||
~BitwiseOrInfo() override = default;
|
||||
};
|
||||
|
||||
class BitwiseXorInfo : public ArithmeticBase {
|
||||
public:
|
||||
BitwiseXorInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<BitwiseXorCost>()) {}
|
||||
~BitwiseXorInfo() override = default;
|
||||
};
|
||||
|
||||
class MulNoNanInfo : public ArithmeticBase {
|
||||
public:
|
||||
MulNoNanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<MulNoNanCost>()) {}
|
||||
~MulNoNanInfo() = default;
|
||||
};
|
||||
|
||||
class TruncateDivInfo : public ArithmeticBase {
|
||||
public:
|
||||
TruncateDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<TruncateDivCost>()) {}
|
||||
~TruncateDivInfo() = default;
|
||||
};
|
||||
|
||||
class TruncateModInfo : public ArithmeticBase {
|
||||
public:
|
||||
TruncateModInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<TruncateModCost>()) {}
|
||||
~TruncateModInfo() = default;
|
||||
};
|
||||
|
||||
class XdivyInfo : public ArithmeticBase {
|
||||
public:
|
||||
XdivyInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<XdivyCost>()) {}
|
||||
~XdivyInfo() = default;
|
||||
};
|
||||
|
||||
class XlogyInfo : public ArithmeticBase {
|
||||
public:
|
||||
XlogyInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<XlogyCost>()) {}
|
||||
~XlogyInfo() = default;
|
||||
};
|
||||
|
||||
class LerpInfo : public ArithmeticBase {
|
||||
public:
|
||||
LerpInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<LerpCost>()) {}
|
||||
~LerpInfo() = default;
|
||||
|
||||
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
|
||||
void ReComputeBatchSplitFlagList() override;
|
||||
|
||||
protected:
|
||||
Status GetAttrs() override;
|
||||
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||
Status InferDevMatrixShape() override;
|
||||
Status InferTensorMap() override;
|
||||
|
||||
private:
|
||||
size_t inputs_size_;
|
||||
};
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -111,5 +111,12 @@ Status BoundingBoxEncodeInfo::PrepareStrategy(int64_t stage_id, int64_t split_nu
|
|||
(*sp) = std::make_shared<Strategy>(stage_id, strategies);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
void BoundingBoxEncodeInfo::ReComputeBatchSplitFlagList() {
|
||||
auto anchor_box_shape = inputs_shape_.at(0);
|
||||
auto gt_box_shape = inputs_shape_.at(1);
|
||||
anchor_box_shape[0] == 1 ? split_flag_list_[0] = false : split_flag_list_[0] = true;
|
||||
gt_box_shape[0] == 1 ? split_flag_list_[1] = false : split_flag_list_[1] = true;
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -36,6 +36,7 @@ class BoundingBoxEncodeInfo : public OperatorInfo {
|
|||
|
||||
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
|
||||
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
|
||||
void ReComputeBatchSplitFlagList() override;
|
||||
|
||||
protected:
|
||||
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||
|
|
|
@ -0,0 +1,118 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "frontend/parallel/ops_info/cdist_info.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
Status CdistInfo::GetAttrs() {
|
||||
input_dims_ = inputs_shape_.at(0).size();
|
||||
if (input_dims_ != 2 && input_dims_ != 3) {
|
||||
MS_LOG(ERROR) << "Dimension of each input must be 2 or 3, but got dimension is " << input_dims_ << ".";
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status CdistInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
auto strategies = strategy->GetInputDim();
|
||||
auto input_x_strategy = strategies.at(0);
|
||||
auto input_y_strategy = strategies.at(1);
|
||||
// input_x shape: (B, P, M), input_y shape: (B, R, M), shard num of B-dim must be equal
|
||||
if (input_dims_ == 3 && input_x_strategy[0] != input_y_strategy[0]) {
|
||||
MS_LOG(ERROR) << name_ << ": Sharding num of batch-dimension must be equal, "
|
||||
<< "but got strategy " << StrategyToString(strategies);
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (input_x_strategy.back() != 1 || input_y_strategy.back() != 1) {
|
||||
MS_LOG(ERROR) << name_ << ": The last dimension of each input cannot be shard, "
|
||||
<< "but got strategy " << StrategyToString(strategies);
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status CdistInfo::InferDevMatrixShape() {
|
||||
dev_matrix_shape_.clear();
|
||||
|
||||
auto strategies = strategy_->GetInputDim();
|
||||
auto input_x_strategy = strategies.at(0);
|
||||
auto input_y_strategy = strategies.at(1);
|
||||
if (input_dims_ == 2) {
|
||||
dev_matrix_shape_ = {input_x_strategy[0], input_y_strategy[0]};
|
||||
} else {
|
||||
dev_matrix_shape_ = {input_x_strategy[0], input_x_strategy[1], input_y_strategy[1]};
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << name_ << ": dev matrix: " << ShapeToString(dev_matrix_shape_);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status CdistInfo::InferTensorMap() {
|
||||
inputs_tensor_map_.clear();
|
||||
outputs_tensor_map_.clear();
|
||||
|
||||
if (input_dims_ == 2) {
|
||||
inputs_tensor_map_ = {{1, -1}, {0, -1}};
|
||||
outputs_tensor_map_ = {{1, 0}};
|
||||
} else {
|
||||
inputs_tensor_map_ = {{2, 1, -1}, {2, 0, -1}};
|
||||
outputs_tensor_map_ = {{2, 1, 0}};
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
std::vector<StrategyPtr> CdistInfo::GenerateOpStrategies(int64_t stage_id) {
|
||||
std::vector<StrategyPtr> sp;
|
||||
Shapes inputs_splittable;
|
||||
if (input_dims_ == 2) {
|
||||
inputs_splittable = {{1, 0}, {2, 0}};
|
||||
} else {
|
||||
inputs_splittable = {{1, 2, 0}, {1, 3, 0}};
|
||||
}
|
||||
if (GenerateStrategiesForDependentInputs(stage_id, inputs_shape_, inputs_splittable, &sp) != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": Generate strategies for dependent inputs() failed.";
|
||||
}
|
||||
return sp;
|
||||
}
|
||||
|
||||
void CdistInfo::ReComputeBatchSplitFlagList() {
|
||||
size_t input_dims = inputs_shape_.at(0).at(0);
|
||||
if (input_dims == 3) {
|
||||
if (inputs_shape_[0][0] != 1) {
|
||||
split_flag_list_[0] = true;
|
||||
split_flag_list_[1] = true;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// if input_dims is 2, only one of them can be split
|
||||
if (inputs_shape_[0][0] != 1) {
|
||||
split_flag_list_[0] = true;
|
||||
} else if (inputs_shape_[1][0] != 1) {
|
||||
split_flag_list_[1] = true;
|
||||
}
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,52 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CDIST_INFO_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CDIST_INFO_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
|
||||
#include "frontend/parallel/ops_info/operator_info.h"
|
||||
#include "frontend/parallel/strategy.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
class CdistInfo : public OperatorInfo {
|
||||
public:
|
||||
CdistInfo(const std::string &name, const Shapes &input_shape, const Shapes &output_shape, const PrimitiveAttrs &attrs)
|
||||
: OperatorInfo(name, input_shape, output_shape, attrs, std::make_shared<CdistCost>()) {}
|
||||
~CdistInfo() override = default;
|
||||
|
||||
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
|
||||
Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); }
|
||||
void ReComputeBatchSplitFlagList() override;
|
||||
|
||||
protected:
|
||||
Status GetAttrs() override;
|
||||
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||
Status InferDevMatrixShape() override;
|
||||
Status InferTensorMap() override;
|
||||
Status InferForwardCommunication() override { return SUCCESS; }
|
||||
|
||||
private:
|
||||
size_t input_dims_;
|
||||
};
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CDIST_INFO_H_
|
|
@ -0,0 +1,97 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <utility>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
|
||||
#include "frontend/parallel/ops_info/inplace_add_info.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
Status InplaceAddInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy";
|
||||
return FAILED;
|
||||
}
|
||||
Strategys strategies = strategy->GetInputDim();
|
||||
auto x_strategy = strategies.at(0);
|
||||
auto input_v_strategy = strategies.at(1);
|
||||
if (x_strategy[0] != 1 || input_v_strategy[0] != 1) {
|
||||
MS_LOG(ERROR) << name_ << ": The 1st dimension of x and input_v is not supported sharding, "
|
||||
<< "but got strategy " << StrategyToString(strategies);
|
||||
return FAILED;
|
||||
}
|
||||
if (x_strategy != input_v_strategy) {
|
||||
MS_LOG(ERROR) << name_ << ": The strategy of x and input_v must be the same, "
|
||||
<< "but got strategy " << StrategyToString(strategies);
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status InplaceAddInfo::InferDevMatrixShape() {
|
||||
dev_matrix_shape_.clear();
|
||||
|
||||
auto x_strategy = strategy_->GetInputDim().at(0);
|
||||
dev_matrix_shape_.assign(x_strategy.begin() + 1, x_strategy.end());
|
||||
|
||||
MS_LOG(INFO) << name_ << ": dev matrix: " << ShapeToString(dev_matrix_shape_);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status InplaceAddInfo::InferTensorMap() {
|
||||
inputs_tensor_map_.clear();
|
||||
outputs_tensor_map_.clear();
|
||||
|
||||
Shape tensor_map = {-1};
|
||||
size_t dev_size = dev_matrix_shape_.size();
|
||||
if (repeated_calc_num_ > 1 && repeated_num_in_dev_matrix_right_) {
|
||||
--dev_size;
|
||||
}
|
||||
size_t start = 0;
|
||||
if (repeated_calc_num_ > 1 && !repeated_num_in_dev_matrix_right_) {
|
||||
++start;
|
||||
}
|
||||
for (size_t i = start; i < dev_size; ++i) {
|
||||
tensor_map.push_back(dev_size - i - 1);
|
||||
}
|
||||
inputs_tensor_map_.push_back(tensor_map);
|
||||
inputs_tensor_map_.push_back(tensor_map);
|
||||
(void)outputs_tensor_map_.emplace_back(std::move(tensor_map));
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
std::vector<StrategyPtr> InplaceAddInfo::GenerateOpStrategies(int64_t stage_id) {
|
||||
Shapes splittable_inputs = inputs_shape_;
|
||||
for (size_t i = 0; i < splittable_inputs.size(); ++i) {
|
||||
for (size_t j = 0; j < splittable_inputs[i].size(); ++j) {
|
||||
splittable_inputs[i][j] = j;
|
||||
}
|
||||
}
|
||||
std::vector<StrategyPtr> sp_vector;
|
||||
if (GenerateStrategiesForDependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": Generate strategies for dependent inputs() failed.";
|
||||
}
|
||||
return sp_vector;
|
||||
}
|
||||
|
||||
void InplaceAddInfo::ReComputeBatchSplitFlagList() {
|
||||
split_flag_list_[0] = false;
|
||||
split_flag_list_[1] = false;
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_INPLACE_ADD_INFO_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_INPLACE_ADD_INFO_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
|
||||
#include "frontend/parallel/ops_info/operator_info.h"
|
||||
#include "frontend/parallel/strategy.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
class InplaceAddInfo : public OperatorInfo {
|
||||
public:
|
||||
InplaceAddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<InplaceAddCost>()) {}
|
||||
~InplaceAddInfo() override = default;
|
||||
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
|
||||
Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); }
|
||||
void ReComputeBatchSplitFlagList() override;
|
||||
|
||||
protected:
|
||||
Status GetAttrs() override { return SUCCESS; }
|
||||
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||
Status InferDevMatrixShape() override;
|
||||
Status InferTensorMap() override;
|
||||
Status InferForwardCommunication() override { return SUCCESS; }
|
||||
};
|
||||
|
||||
class InplaceSubInfo : public InplaceAddInfo {
|
||||
public:
|
||||
InplaceSubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: InplaceAddInfo(name, inputs_shape, outputs_shape, attrs) {}
|
||||
~InplaceSubInfo() = default;
|
||||
};
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_INPLACE_ADD_INFO_H_
|
|
@ -22,6 +22,7 @@
|
|||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "ir/dtype.h"
|
||||
#include "ir/tensor.h"
|
||||
|
@ -1415,6 +1416,68 @@ Status GenerateStrategiesForIndependentInputs(int64_t stage_id, const Shapes &in
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
// 'splittable_inputs' has the same dimensions as 'inputs_shape_'. '0' in 'splittable_inputs' means that
|
||||
// the corresponding dimension is unsplittable, otherwise means that the corresponding dimension is splittable.
|
||||
// In particular, if the same dimensions exist in 'splittable_inputs',
|
||||
// the corresponding dimensions in the strategy are the same.
|
||||
// 'sp' is the result of partitions.
|
||||
Status GenerateStrategiesForDependentInputs(int64_t stage_id, const Shapes &inputs_shape,
|
||||
const Shapes &splittable_inputs, std::vector<StrategyPtr> *sp) {
|
||||
if (inputs_shape.size() != splittable_inputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "Size of inputs_shape and splittable_inputs are not equal.";
|
||||
}
|
||||
|
||||
std::unordered_map<int64_t, int64_t> mp;
|
||||
for (size_t i = 0; i < inputs_shape.size(); ++i) {
|
||||
auto input_shape = inputs_shape[i];
|
||||
auto splittable_input = splittable_inputs[i];
|
||||
for (size_t j = 0; j < input_shape.size(); ++j) {
|
||||
int64_t indice = splittable_input[j];
|
||||
int64_t shape = input_shape[j];
|
||||
if (splittable_input[j] == 0) {
|
||||
continue;
|
||||
}
|
||||
if (mp.find(indice) == mp.end()) {
|
||||
mp[indice] = shape;
|
||||
} else {
|
||||
mp[indice] = std::gcd(mp[indice], shape);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_map<int64_t, size_t> indices_mp;
|
||||
Shape tmp_input_shape;
|
||||
Shapes tmp_splittable_inputs = {Shape(mp.size(), 1)};
|
||||
|
||||
for (const auto &item : mp) {
|
||||
indices_mp[item.first] = tmp_input_shape.size();
|
||||
tmp_input_shape.push_back(item.second);
|
||||
}
|
||||
Shapes tmp_inputs_shape = {tmp_input_shape};
|
||||
std::vector<StrategyPtr> tmp_sp_vector;
|
||||
if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, &tmp_sp_vector) !=
|
||||
SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
std::transform(tmp_sp_vector.begin(), tmp_sp_vector.end(), std::back_inserter(*sp),
|
||||
[stage_id, &indices_mp, &splittable_inputs](const StrategyPtr &sp) {
|
||||
auto tmp_strategy = sp->GetInputDim().at(0);
|
||||
Strategys strategies(splittable_inputs);
|
||||
for (size_t i = 0; i < strategies.size(); ++i) {
|
||||
for (size_t j = 0; j < strategies[i].size(); ++j) {
|
||||
if (splittable_inputs[i][j] == 0) {
|
||||
strategies[i][j] = 1;
|
||||
} else {
|
||||
strategies[i][j] = tmp_strategy[indices_mp[splittable_inputs[i][j]]];
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::make_shared<Strategy>(stage_id, strategies);
|
||||
});
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
// generate strategies for that have two inputs, and input0 or input1 maybe broadcast,
|
||||
// and the corresponding dimensions that are not broadcast are all relevant dimensions
|
||||
// such as: ([a, b, c, d], [a, b, c, d]) or ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d])
|
||||
|
|
|
@ -356,6 +356,9 @@ Status GenerateStrategiesForIndependentInputsBase(int64_t stage_id, size_t dev_n
|
|||
// generate strategies for that all inputs' dimensions are independent, such as: ([a, b, c, d])
|
||||
Status GenerateStrategiesForIndependentInputs(int64_t stage_id, const Shapes &inputs_shape,
|
||||
const Shapes &splittable_inputs, std::vector<StrategyPtr> *sp_vector);
|
||||
// generate strategies for that inputs' dimension maybe dependent
|
||||
Status GenerateStrategiesForDependentInputs(int64_t stage_id, const Shapes &inputs_shape,
|
||||
const Shapes &splittable_inputs, std::vector<StrategyPtr> *sp);
|
||||
// generate strategies for that have two inputs, and input0 or input1 maybe broadcast,
|
||||
// and the corresponding dimensions that are not broadcast are all relevant dimensions
|
||||
// such as: ([a, b, c, d], [a, b, c, d]) or ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d])
|
||||
|
|
|
@ -67,5 +67,8 @@
|
|||
#include "frontend/parallel/ops_info/random_choice_with_mask_info.h"
|
||||
#include "frontend/parallel/ops_info/crop_and_resize_info.h"
|
||||
#include "frontend/parallel/ops_info/roi_align_info.h"
|
||||
#include "frontend/parallel/ops_info/addn_info.h"
|
||||
#include "frontend/parallel/ops_info/inplace_add_info.h"
|
||||
#include "frontend/parallel/ops_info/cdist_info.h"
|
||||
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_
|
||||
|
|
|
@ -57,6 +57,7 @@ constexpr size_t ACTIVATION_INPUTS_SIZE = 1;
|
|||
constexpr size_t ACTIVATION_OUTPUTS_SIZE = 1;
|
||||
constexpr size_t EXPANDDIMS_INPUT_SIZE = 2;
|
||||
constexpr size_t CUMSUM_INPUT_SIZE = 2;
|
||||
constexpr size_t CUM_OP_INPUT_SIZE = 2;
|
||||
constexpr size_t DROPOUT_DO_MASK_CNODE_INPUT_SIZE = 4;
|
||||
constexpr size_t DROPOUT_GEN_MASK_CNODE_INPUT_SIZE = 3;
|
||||
constexpr size_t DROPOUT_GEN_MASK_INDEX = 2;
|
||||
|
@ -447,7 +448,8 @@ constexpr char GATHERD[] = "GatherD";
|
|||
constexpr char DSD_MATMUL[] = "DSDMatmul";
|
||||
constexpr char RESIZE_BILINEAR[] = "ResizeBilinear";
|
||||
constexpr char RESIZE_NEAREST_NEIGHBOR[] = "ResizeNearestNeighbor";
|
||||
constexpr char CUMSUM[] = "CumSum";
|
||||
constexpr char CUM_SUM[] = "CumSum";
|
||||
constexpr char CUM_PROD[] = "CumProd";
|
||||
constexpr char BOUNDING_BOX_ENCODE[] = "BoundingBoxEncode";
|
||||
constexpr char IOU[] = "IOU";
|
||||
constexpr char RANDOM_CHOICE_WITH_MASK[] = "RandomChoiceWithMask";
|
||||
|
@ -455,6 +457,26 @@ constexpr char CROP_AND_RESIZE[] = "CropAndResize";
|
|||
constexpr char MASKED_FILL[] = "MaskedFill";
|
||||
constexpr char ROI_ALIGN[] = "ROIAlign";
|
||||
constexpr char SQUARE_SUM_ALL[] = "SquareSumAll";
|
||||
constexpr char IS_FINITE[] = "IsFinite";
|
||||
constexpr char RINT[] = "Rint";
|
||||
constexpr char HSHRINK[] = "HShrink";
|
||||
constexpr char HSIGMOID[] = "HSigmoid";
|
||||
constexpr char MISH[] = "Mish";
|
||||
constexpr char SELU[] = "SeLU";
|
||||
constexpr char SOFT_SHRINK[] = "SoftShrink";
|
||||
constexpr char XLOGY[] = "Xlogy";
|
||||
constexpr char XDIVY[] = "Xdivy";
|
||||
constexpr char BITWISE_AND[] = "BitwiseAnd";
|
||||
constexpr char BITWISE_OR[] = "BitwiseOr";
|
||||
constexpr char BITWISE_XOR[] = "BitwiseXor";
|
||||
constexpr char MUL_NO_NAN[] = "MulNoNan";
|
||||
constexpr char TRUNCATE_DIV[] = "TruncateDiv";
|
||||
constexpr char TRUNCATE_MOD[] = "TruncateMod";
|
||||
constexpr char INPLACE_ADD[] = "InplaceAdd";
|
||||
constexpr char INPLACE_SUB[] = "InplaceSub";
|
||||
constexpr char CDIST[] = "Cdist";
|
||||
constexpr char L2_LOSS[] = "L2Loss";
|
||||
constexpr char LERP[] = "Lerp";
|
||||
|
||||
// pipeline
|
||||
constexpr size_t PIPELINE_FUSTION_OFFSET = 100;
|
||||
|
|
|
@ -117,5 +117,7 @@ Status RandomChoiceWithMaskInfo::InitForCostModel(const StrategyPtr &strategy, c
|
|||
CheckGPUBackend();
|
||||
return OperatorInfo::InitForCostModel(strategy, out_strategy);
|
||||
}
|
||||
|
||||
void RandomChoiceWithMaskInfo::ReComputeBatchSplitFlagList() { split_flag_list_[0] = false; }
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -40,6 +40,7 @@ class RandomChoiceWithMaskInfo : public OperatorInfo {
|
|||
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
|
||||
Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); }
|
||||
void ReplaceNodeInputOrAttrs() override;
|
||||
void ReComputeBatchSplitFlagList() override;
|
||||
|
||||
protected:
|
||||
Status GetAttrs() override;
|
||||
|
|
|
@ -173,9 +173,13 @@ bool IsSplittableOperator(const std::string &op_name) {
|
|||
SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
|
||||
UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, SELECT, GATHERD,
|
||||
UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, VIRTUAL_OUTPUT, CONV2D_BACK_PROP_INPUT, CONV2D_TRANSPOSE,
|
||||
MATMUL_DDS, DSD_MATMUL, UNIFORMREAL, RESIZE_BILINEAR, RESIZE_NEAREST_NEIGHBOR, CUMSUM, FAST_GELU, IOU,
|
||||
BOUNDING_BOX_ENCODE, RANDOM_CHOICE_WITH_MASK, CROP_AND_RESIZE, ROI_ALIGN, REDUCE_PROD, REDUCE_ANY, REDUCE_ALL,
|
||||
ARGMAX, ARGMIN, UNSORTED_SEGMENT_PROD, SQUARE_SUM_ALL};
|
||||
MATMUL_DDS, DSD_MATMUL, UNIFORMREAL, RESIZE_BILINEAR, RESIZE_NEAREST_NEIGHBOR, FAST_GELU, IOU, BOUNDING_BOX_ENCODE,
|
||||
RANDOM_CHOICE_WITH_MASK, CROP_AND_RESIZE, ROI_ALIGN, REDUCE_PROD, REDUCE_ANY, REDUCE_ALL, ARGMAX, ARGMIN,
|
||||
UNSORTED_SEGMENT_PROD, SQUARE_SUM_ALL, MATMUL_DDS, DSD_MATMUL, UNIFORMREAL, RESIZE_BILINEAR,
|
||||
RESIZE_NEAREST_NEIGHBOR, CUM_SUM, FAST_GELU, IOU, BOUNDING_BOX_ENCODE, RANDOM_CHOICE_WITH_MASK, CROP_AND_RESIZE,
|
||||
ROI_ALIGN, IS_FINITE, RINT, HSHRINK, HSIGMOID, MISH, SELU, SOFT_SHRINK, XLOGY, XDIVY, CUM_PROD, BITWISE_AND,
|
||||
BITWISE_OR, BITWISE_XOR, MUL_NO_NAN, TRUNCATE_DIV, TRUNCATE_MOD, INPLACE_ADD, INPLACE_SUB, L2_LOSS, LERP, ADDN,
|
||||
CDIST};
|
||||
// clang-format on
|
||||
|
||||
auto iter = splittable_op.find(op_name);
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, context
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
from parallel.utils.utils import compile_net
|
||||
|
||||
x_ = (Tensor(np.random.normal(size=[8, 8, 8])),
|
||||
Tensor(np.random.normal(size=[8, 8, 8])),
|
||||
Tensor(np.random.normal(size=[8, 8, 8])))
|
||||
|
||||
|
||||
class Net(Cell):
|
||||
def __init__(self, strategy=None):
|
||||
super(Net, self).__init__()
|
||||
self.addn = P.AddN().shard(strategy)
|
||||
|
||||
def construct(self, x):
|
||||
return self.addn(x)
|
||||
|
||||
|
||||
def test_addn_auto_parallel():
|
||||
"""
|
||||
Feature: test AddN auto parallel
|
||||
Description: auto parallel
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||
net = Net()
|
||||
compile_net(net, x_)
|
||||
|
||||
|
||||
def test_addn_model_parallel():
|
||||
"""
|
||||
Feature: test AddN model parallel
|
||||
Description: model parallel
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy = ((2, 2, 2), (2, 2, 2), (2, 2, 2))
|
||||
net = Net(strategy)
|
||||
compile_net(net, x_)
|
||||
|
||||
|
||||
def test_addn_strategy_error():
|
||||
"""
|
||||
Feature: test invalid strategy for AddN
|
||||
Description: illegal strategy
|
||||
Expectation: raise RuntimeError
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy = ((2, 2, 2), (2, 2, 2), (2, 2, 1))
|
||||
net = Net(strategy)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net, x_)
|
|
@ -884,3 +884,229 @@ def test_assign():
|
|||
net = SubGradWrap(SubNetWithLoss(Net()))
|
||||
x = Tensor(np.ones([128, 32]), dtype=ms.float32)
|
||||
compile_sub_net(net, x)
|
||||
|
||||
|
||||
def test_matmul_bitwise_and_broadcast():
|
||||
"""
|
||||
Feature: distribute operator BitwiseAnd in auto parallel.
|
||||
Description: mul-BitwiseAnd net with strategy in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, strategy1, strategy2):
|
||||
super().__init__()
|
||||
self.bitwise_and = P.BitwiseAnd().shard(strategy1)
|
||||
self.matmul = P.MatMul().shard(strategy2)
|
||||
|
||||
|
||||
def construct(self, x, y, z):
|
||||
out = self.bitwise_and(x, y)
|
||||
out = self.matmul(out, z)
|
||||
|
||||
return out
|
||||
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
strategy1 = ((2, 1), (1, 4))
|
||||
strategy2 = ((1, 4), (4, 2))
|
||||
net = Net(strategy1, strategy2)
|
||||
|
||||
x = Tensor(np.ones([64, 1]), dtype=ms.int32)
|
||||
y = Tensor(np.ones([1, 64]), dtype=ms.int32)
|
||||
z = Tensor(np.ones([64, 32]), dtype=ms.int32)
|
||||
compile_net(net, x, y, z)
|
||||
|
||||
|
||||
def test_matmul_bitwise_or_broadcast():
|
||||
"""
|
||||
Feature: distribute operator BitwiseOr in auto parallel.
|
||||
Description: mul-BitwiseOr net with strategy in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, strategy1, strategy2):
|
||||
super().__init__()
|
||||
self.bitwise_or = P.BitwiseOr().shard(strategy1)
|
||||
self.matmul = P.MatMul().shard(strategy2)
|
||||
|
||||
def construct(self, x, y, z):
|
||||
out = self.bitwise_or(x, y)
|
||||
out = self.matmul(out, z)
|
||||
return out
|
||||
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
strategy1 = ((2, 1), (1, 4))
|
||||
strategy2 = ((1, 4), (4, 2))
|
||||
net = Net(strategy1, strategy2)
|
||||
|
||||
x = Tensor(np.ones([64, 1]), dtype=ms.int32)
|
||||
y = Tensor(np.ones([1, 64]), dtype=ms.int32)
|
||||
z = Tensor(np.ones([64, 32]), dtype=ms.int32)
|
||||
compile_net(net, x, y, z)
|
||||
|
||||
|
||||
def test_matmul_bitwise_xor_broadcast():
|
||||
"""
|
||||
Feature: distribute operator BitwiseXor in auto parallel.
|
||||
Description: mul-BitwiseXor net with strategy in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, strategy1, strategy2):
|
||||
super().__init__()
|
||||
self.bitwise_xor = P.BitwiseXor().shard(strategy1)
|
||||
self.matmul = P.MatMul().shard(strategy2)
|
||||
|
||||
def construct(self, x, y, z):
|
||||
out = self.bitwise_xor(x, y)
|
||||
out = self.matmul(out, z)
|
||||
return out
|
||||
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
strategy1 = ((2, 1), (1, 4))
|
||||
strategy2 = ((1, 4), (4, 2))
|
||||
net = Net(strategy1, strategy2)
|
||||
|
||||
x = Tensor(np.ones([64, 1]), dtype=ms.int32)
|
||||
y = Tensor(np.ones([1, 64]), dtype=ms.int32)
|
||||
z = Tensor(np.ones([64, 32]), dtype=ms.int32)
|
||||
compile_net(net, x, y, z)
|
||||
|
||||
|
||||
def test_matmul_mul_no_nan_broadcast():
|
||||
"""
|
||||
Feature: distribute operator MulNoNan in auto parallel.
|
||||
Description: mul-MulNoNan net with strategy in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, strategy1, strategy2):
|
||||
super().__init__()
|
||||
self.matmul = P.MatMul().shard(strategy1)
|
||||
self.mul_no_nan = P.MulNoNan().shard(strategy2)
|
||||
|
||||
def construct(self, x, y, b):
|
||||
out = self.matmul(x, y)
|
||||
out = self.mul_no_nan(out, b)
|
||||
return out
|
||||
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
strategy1 = ((2, 4), (4, 1))
|
||||
strategy2 = ((4, 1), (1, 2))
|
||||
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
|
||||
|
||||
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([32, 1]), dtype=ms.float32)
|
||||
b = Tensor(np.ones([1, 64]), dtype=ms.float32)
|
||||
compile_net(net, x, y, b)
|
||||
|
||||
|
||||
def test_matmul_truncate_div_broadcast():
|
||||
"""
|
||||
Feature: distribute operator TruncateDiv in auto parallel.
|
||||
Description: mul-TruncateDiv net with strategy in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, strategy1, strategy2):
|
||||
super().__init__()
|
||||
self.matmul = P.MatMul().shard(strategy1)
|
||||
self.truncate_div = P.TruncateDiv().shard(strategy2)
|
||||
|
||||
def construct(self, x, y, b):
|
||||
out = self.matmul(x, y)
|
||||
out = self.truncate_div(out, b)
|
||||
return out
|
||||
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
strategy1 = ((2, 4), (4, 1))
|
||||
strategy2 = ((4, 1), (1, 2))
|
||||
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
|
||||
|
||||
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([32, 1]), dtype=ms.float32)
|
||||
b = Tensor(np.ones([1, 64]), dtype=ms.float32)
|
||||
compile_net(net, x, y, b)
|
||||
|
||||
|
||||
def test_matmul_truncate_mod_broadcast():
|
||||
"""
|
||||
Feature: distribute operator TruncateMod in auto parallel.
|
||||
Description: mul-TruncateMod net with strategy in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, strategy1, strategy2):
|
||||
super().__init__()
|
||||
self.matmul = P.MatMul().shard(strategy1)
|
||||
self.truncate_mod = P.TruncateMod().shard(strategy2)
|
||||
|
||||
def construct(self, x, y, b):
|
||||
out = self.matmul(x, y)
|
||||
out = self.truncate_mod(out, b)
|
||||
return out
|
||||
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
strategy1 = ((2, 4), (4, 1))
|
||||
strategy2 = ((4, 1), (1, 2))
|
||||
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
|
||||
|
||||
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([32, 1]), dtype=ms.float32)
|
||||
b = Tensor(np.ones([1, 64]), dtype=ms.float32)
|
||||
compile_net(net, x, y, b)
|
||||
|
||||
|
||||
def test_matmul_xdivy_broadcast():
|
||||
"""
|
||||
Feature: distribute operator Xdivy in auto parallel.
|
||||
Description: mul-Xdivy net with strategy in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, strategy1, strategy2):
|
||||
super().__init__()
|
||||
self.matmul = P.MatMul().shard(strategy1)
|
||||
self.xdivy = P.Xdivy().shard(strategy2)
|
||||
|
||||
def construct(self, x, y, b):
|
||||
out = self.matmul(x, y)
|
||||
out = self.xdivy(out, b)
|
||||
return out
|
||||
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
strategy1 = ((2, 4), (4, 1))
|
||||
strategy2 = ((4, 1), (1, 2))
|
||||
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
|
||||
|
||||
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([32, 1]), dtype=ms.float32)
|
||||
b = Tensor(np.ones([1, 64]), dtype=ms.float32)
|
||||
compile_net(net, x, y, b)
|
||||
|
||||
|
||||
def test_matmul_xlogy_broadcast():
|
||||
"""
|
||||
Feature: distribute operator Xlogy in auto parallel.
|
||||
Description: mul-Xlogy net with strategy in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, strategy1, strategy2):
|
||||
super().__init__()
|
||||
self.matmul = P.MatMul().shard(strategy1)
|
||||
self.xlogy = P.Xlogy().shard(strategy2)
|
||||
|
||||
def construct(self, x, y, b):
|
||||
out = self.matmul(x, y)
|
||||
out = self.xlogy(out, b)
|
||||
return out
|
||||
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||
strategy1 = ((2, 4), (4, 1))
|
||||
strategy2 = ((4, 1), (1, 2))
|
||||
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
|
||||
|
||||
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([32, 1]), dtype=ms.float32)
|
||||
b = Tensor(np.ones([1, 64]), dtype=ms.float32)
|
||||
compile_net(net, x, y, b)
|
||||
|
|
|
@ -0,0 +1,139 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mindspore import Tensor, context
|
||||
from mindspore.nn import Cell
|
||||
import mindspore.ops as ops
|
||||
|
||||
from parallel.utils.utils import compile_net
|
||||
|
||||
B = 8
|
||||
P = 8
|
||||
R = 8
|
||||
M = 2
|
||||
|
||||
input_x_2d_ = Tensor(np.random.normal(size=[P, M]).astype(np.float32))
|
||||
input_y_2d_ = Tensor(np.random.normal(size=[R, M]).astype(np.float32))
|
||||
input_x_3d_ = Tensor(np.random.normal(size=[B, P, M]).astype(np.float32))
|
||||
input_y_3d_ = Tensor(np.random.normal(size=[B, R, M]).astype(np.float32))
|
||||
|
||||
|
||||
class Net(Cell):
|
||||
def __init__(self, strategy=None):
|
||||
super(Net, self).__init__()
|
||||
self.cdist = ops.Cdist().shard(strategy)
|
||||
|
||||
def construct(self, *inputs):
|
||||
output = self.cdist(*inputs)
|
||||
return output
|
||||
|
||||
|
||||
def test_cdist_2d_auto_parallel():
|
||||
"""
|
||||
Feature: test Cdist-2d in parallel
|
||||
Description: auto parallel with 2d inputs
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||
net = Net()
|
||||
compile_net(net, input_x_2d_, input_y_2d_)
|
||||
|
||||
|
||||
def test_cdist_2d_data_parallel():
|
||||
"""
|
||||
Feature: test Cdist-2d in parallel
|
||||
Description: data parallel with 2d inputs
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy = ((4, 1), (2, 1))
|
||||
net = Net(strategy)
|
||||
compile_net(net, input_x_2d_, input_y_2d_)
|
||||
|
||||
|
||||
def test_cdist_2d_data_parallel_with_repeated_cal():
|
||||
"""
|
||||
Feature: test Cdist-2d in parallel with repeated calculation
|
||||
Description: data parallel with 2d inputs
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy = ((2, 1), (2, 1))
|
||||
net = Net(strategy)
|
||||
compile_net(net, input_x_2d_, input_y_2d_)
|
||||
|
||||
|
||||
def test_cdist_2d_strategy_error():
|
||||
"""
|
||||
Feature: test invalid strategy for Cdist 2d
|
||||
Description: illegal strategy with 2d inputs
|
||||
Expectation: raise RuntimeError
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy = ((2, 2), (2, 1))
|
||||
net = Net(strategy)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net, input_x_2d_, input_y_2d_)
|
||||
|
||||
|
||||
def test_cdist_3d_auto_parallel():
|
||||
"""
|
||||
Feature: test Cdist-3d in parallel
|
||||
Description: auto parallel with 3d inputs
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||
net = Net()
|
||||
compile_net(net, input_x_3d_, input_y_3d_)
|
||||
|
||||
|
||||
def test_cdist_3d_data_parallel():
|
||||
"""
|
||||
Feature: test Cdist-3d in parallel
|
||||
Description: data parallel with 3d inputs
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy = ((2, 2, 1), (2, 2, 1))
|
||||
net = Net(strategy)
|
||||
compile_net(net, input_x_3d_, input_y_3d_)
|
||||
|
||||
|
||||
def test_cdist_3d_data_parallel_with_repeated_cal():
|
||||
"""
|
||||
Feature: test Cdist-3d in parallel with repeated calculation
|
||||
Description: data parallel with 3d inputs
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy = ((2, 2, 1), (2, 1, 1))
|
||||
net = Net(strategy)
|
||||
compile_net(net, input_x_3d_, input_y_3d_)
|
||||
|
||||
|
||||
def test_cdist_3d_strategy_error():
|
||||
"""
|
||||
Feature: test invalid strategy for Cdist 3d
|
||||
Description: illegal strategy with 3d inputs
|
||||
Expectation: raise RuntimeError
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy = ((2, 2, 1), (1, 2, 1))
|
||||
net = Net(strategy)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net, input_x_3d_, input_y_3d_)
|
|
@ -954,3 +954,222 @@ def test_mul_two_cast():
|
|||
y = Tensor(np.ones([128, 32]), dtype=ms.float32)
|
||||
b = Tensor(np.ones([128, 32]), dtype=ms.float32)
|
||||
compile_net(net, x, y, b)
|
||||
|
||||
|
||||
def test_matmul_hshrink():
|
||||
"""
|
||||
Feature: distribute operator HShrink in auto parallel.
|
||||
Description: matmul-hshrink-matmul net with strategy in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, strategy1, strategy2):
|
||||
super().__init__()
|
||||
self.matmul = P.MatMul().shard(strategy1)
|
||||
self.hshrink = P.HShrink().shard(strategy2)
|
||||
self.matmul2 = P.MatMul().shard(strategy1)
|
||||
|
||||
def construct(self, x, y, b):
|
||||
out = self.matmul(x, y)
|
||||
out = self.hshrink(out)
|
||||
out = self.matmul2(out, b)
|
||||
return out
|
||||
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2), (2, 2))
|
||||
strategy2 = ((4, 2),)
|
||||
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
|
||||
x = Tensor(np.random.uniform(-5, 5, size=(128, 32)), dtype=ms.float32)
|
||||
y = Tensor(np.random.uniform(-5, 5, size=(32, 64)), dtype=ms.float32)
|
||||
b = Tensor(np.random.uniform(-5, 5, size=(64, 64)), dtype=ms.float32)
|
||||
compile_net(net, x, y, b)
|
||||
|
||||
|
||||
def test_matmul_hsigmoid():
|
||||
"""
|
||||
Feature: distribute operator HSigmoid in auto parallel.
|
||||
Description: matmul-hsigmoid-matmul net with strategy in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, strategy1, strategy2):
|
||||
super().__init__()
|
||||
self.matmul = P.MatMul().shard(strategy1)
|
||||
self.hsigmoid = P.HSigmoid().shard(strategy2)
|
||||
self.matmul2 = P.MatMul().shard(strategy1)
|
||||
|
||||
def construct(self, x, y, b):
|
||||
out = self.matmul(x, y)
|
||||
out = self.hsigmoid(out)
|
||||
out = self.matmul2(out, b)
|
||||
return out
|
||||
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2), (2, 2))
|
||||
strategy2 = ((4, 2),)
|
||||
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
|
||||
x = Tensor(np.random.uniform(-5, 5, size=(128, 32)), dtype=ms.float32)
|
||||
y = Tensor(np.random.uniform(-5, 5, size=(32, 64)), dtype=ms.float32)
|
||||
b = Tensor(np.random.uniform(-5, 5, size=(64, 64)), dtype=ms.float32)
|
||||
compile_net(net, x, y, b)
|
||||
|
||||
|
||||
def test_matmul_is_finite():
|
||||
"""
|
||||
Feature: distribute operator IsFinite in auto parallel.
|
||||
Description: matmul-is_finite-cast-matmul net with strategy in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, strategy1, strategy2):
|
||||
super().__init__()
|
||||
self.matmul = P.MatMul().shard(strategy1)
|
||||
self.is_finite = P.IsFinite().shard(strategy2)
|
||||
self.cast = P.Cast().shard(strategy2)
|
||||
self.matmul2 = P.MatMul().shard(strategy1)
|
||||
|
||||
def construct(self, x, y, b):
|
||||
out = self.matmul(x, y)
|
||||
out = self.is_finite(out)
|
||||
out = self.cast(out, ms.float32)
|
||||
out = self.matmul2(out, b)
|
||||
return out
|
||||
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2), (2, 2))
|
||||
strategy2 = ((4, 2),)
|
||||
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
|
||||
x = Tensor(np.random.uniform(-5, 5, size=(128, 32)), dtype=ms.float32)
|
||||
y = Tensor(np.random.uniform(-5, 5, size=(32, 64)), dtype=ms.float32)
|
||||
b = Tensor(np.random.uniform(-5, 5, size=(64, 64)), dtype=ms.float32)
|
||||
compile_net(net, x, y, b)
|
||||
|
||||
|
||||
def test_matmul_mish():
|
||||
"""
|
||||
Feature: distribute operator Mish in auto parallel.
|
||||
Description: matmul-mish-matmul net with strategy in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, strategy1, strategy2):
|
||||
super().__init__()
|
||||
self.matmul = P.MatMul().shard(strategy1)
|
||||
self.mish = P.Mish().shard(strategy2)
|
||||
self.matmul2 = P.MatMul().shard(strategy1)
|
||||
|
||||
def construct(self, x, y, b):
|
||||
out = self.matmul(x, y)
|
||||
out = self.mish(out)
|
||||
out = self.matmul2(out, b)
|
||||
return out
|
||||
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2), (2, 2))
|
||||
strategy2 = ((4, 2),)
|
||||
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
|
||||
x = Tensor(np.random.uniform(-5, 5, size=(128, 32)), dtype=ms.float32)
|
||||
y = Tensor(np.random.uniform(-5, 5, size=(32, 64)), dtype=ms.float32)
|
||||
b = Tensor(np.random.uniform(-5, 5, size=(64, 64)), dtype=ms.float32)
|
||||
compile_net(net, x, y, b)
|
||||
|
||||
|
||||
def test_matmul_rint():
|
||||
"""
|
||||
Feature: distribute operator Rint in auto parallel.
|
||||
Description: matmul-rint-matmul net with strategy in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, strategy1, strategy2):
|
||||
super().__init__()
|
||||
self.matmul = P.MatMul().shard(strategy1)
|
||||
self.rint = P.Rint().shard(strategy2)
|
||||
self.matmul2 = P.MatMul().shard(strategy1)
|
||||
|
||||
def construct(self, x, y, b):
|
||||
out = self.matmul(x, y)
|
||||
out = self.rint(out)
|
||||
out = self.matmul2(out, b)
|
||||
return out
|
||||
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2), (2, 2))
|
||||
strategy2 = ((4, 2),)
|
||||
net = Net(strategy1, strategy2)
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
|
||||
x = Tensor(np.random.uniform(-5, 5, size=(128, 32)), dtype=ms.float32)
|
||||
y = Tensor(np.random.uniform(-5, 5, size=(32, 64)), dtype=ms.float32)
|
||||
b = Tensor(np.random.uniform(-5, 5, size=(64, 64)), dtype=ms.float32)
|
||||
compile_net(net, x, y, b)
|
||||
|
||||
|
||||
def test_selu_mish():
|
||||
"""
|
||||
Feature: distribute operator SeLU in auto parallel.
|
||||
Description: matmul-selu-matmul net with strategy in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, strategy1, strategy2):
|
||||
super().__init__()
|
||||
self.matmul = P.MatMul().shard(strategy1)
|
||||
self.selu = P.SeLU().shard(strategy2)
|
||||
self.matmul2 = P.MatMul().shard(strategy1)
|
||||
|
||||
def construct(self, x, y, b):
|
||||
out = self.matmul(x, y)
|
||||
out = self.selu(out)
|
||||
out = self.matmul2(out, b)
|
||||
return out
|
||||
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2), (2, 2))
|
||||
strategy2 = ((4, 2),)
|
||||
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
|
||||
x = Tensor(np.random.uniform(-5, 5, size=(128, 32)), dtype=ms.float32)
|
||||
y = Tensor(np.random.uniform(-5, 5, size=(32, 64)), dtype=ms.float32)
|
||||
b = Tensor(np.random.uniform(-5, 5, size=(64, 64)), dtype=ms.float32)
|
||||
compile_net(net, x, y, b)
|
||||
|
||||
|
||||
def test_matmul_soft_shrink():
|
||||
"""
|
||||
Feature: distribute operator SoftShrink in auto parallel.
|
||||
Description: matmul-soft_shrink-matmul net with strategy in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, strategy1, strategy2):
|
||||
super().__init__()
|
||||
self.matmul = P.MatMul().shard(strategy1)
|
||||
self.soft_shrink = P.SoftShrink().shard(strategy2)
|
||||
self.matmul2 = P.MatMul().shard(strategy1)
|
||||
|
||||
def construct(self, x, y, b):
|
||||
out = self.matmul(x, y)
|
||||
out = self.soft_shrink(out)
|
||||
out = self.matmul2(out, b)
|
||||
return out
|
||||
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2), (2, 2))
|
||||
strategy2 = ((4, 2),)
|
||||
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
|
||||
x = Tensor(np.random.uniform(-5, 5, size=(128, 32)), dtype=ms.float32)
|
||||
y = Tensor(np.random.uniform(-5, 5, size=(32, 64)), dtype=ms.float32)
|
||||
b = Tensor(np.random.uniform(-5, 5, size=(64, 64)), dtype=ms.float32)
|
||||
compile_net(net, x, y, b)
|
||||
|
|
|
@ -0,0 +1,140 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
from parallel.utils.utils import compile_net
|
||||
|
||||
x_ = Tensor(np.random.normal(size=[32, 8, 8]).astype(np.float32))
|
||||
input_v_ = Tensor(np.random.normal(size=[16, 8, 8]).astype(np.float32))
|
||||
indices_ = tuple(range(16))
|
||||
|
||||
|
||||
class InplaceAddNet(Cell):
|
||||
def __init__(self, indices, strategy=None):
|
||||
super(InplaceAddNet, self).__init__()
|
||||
self.inplace_add = P.InplaceAdd(indices).shard(strategy)
|
||||
|
||||
def construct(self, x, input_v):
|
||||
return self.inplace_add(x, input_v)
|
||||
|
||||
|
||||
class InplaceSubNet(Cell):
|
||||
def __init__(self, indices, strategy=None):
|
||||
super(InplaceSubNet, self).__init__()
|
||||
self.inplace_sub = P.InplaceSub(indices).shard(strategy)
|
||||
|
||||
def construct(self, x, input_v):
|
||||
return self.inplace_sub(x, input_v)
|
||||
|
||||
|
||||
def test_inplace_add_auto_parallel():
|
||||
"""
|
||||
Feature: test InplaceAdd auto parallel
|
||||
Description: auto parallel
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||
net = InplaceAddNet(indices_)
|
||||
compile_net(net, x_, input_v_)
|
||||
|
||||
|
||||
def test_inplace_add_model_parallel():
|
||||
"""
|
||||
Feature: test InplaceAdd model parallel
|
||||
Description: model parallel
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy = ((1, 4, 2), (1, 4, 2))
|
||||
net = InplaceAddNet(indices_, strategy)
|
||||
compile_net(net, x_, input_v_)
|
||||
|
||||
|
||||
def test_inplace_add_model_parallel_with_repeated_cal():
|
||||
"""
|
||||
Feature: test InplaceAdd model parallel with repeated calculation
|
||||
Description: model parallel
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy = ((1, 2, 2), (1, 2, 2))
|
||||
net = InplaceAddNet(indices_, strategy)
|
||||
compile_net(net, x_, input_v_)
|
||||
|
||||
|
||||
def test_inplace_add_strategy_error():
|
||||
"""
|
||||
Feature: test invalid strategy for InplaceAdd
|
||||
Description: illegal strategy
|
||||
Expectation: raise RuntimeError
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy = ((1, 4, 2), (1, 2, 4))
|
||||
net = InplaceAddNet(indices_, strategy)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net, x_, input_v_)
|
||||
|
||||
|
||||
def test_inplace_sub_auto_parallel():
|
||||
"""
|
||||
Feature: test InplaceSub auto parallel
|
||||
Description: auto parallel
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||
net = InplaceSubNet(indices_)
|
||||
compile_net(net, x_, input_v_)
|
||||
|
||||
|
||||
def test_inplace_sub_model_parallel():
|
||||
"""
|
||||
Feature: test InplaceSub model parallel
|
||||
Description: model parallel
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy = ((1, 4, 2), (1, 4, 2))
|
||||
net = InplaceSubNet(indices_, strategy)
|
||||
compile_net(net, x_, input_v_)
|
||||
|
||||
|
||||
def test_inplace_sub_model_parallel_with_repeated_cal():
|
||||
"""
|
||||
Feature: test InplaceSub model parallel with repeated calculation
|
||||
Description: model parallel
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy = ((1, 2, 2), (1, 2, 2))
|
||||
net = InplaceSubNet(indices_, strategy)
|
||||
compile_net(net, x_, input_v_)
|
||||
|
||||
|
||||
def test_inplace_sub_strategy_error():
|
||||
"""
|
||||
Feature: test invalid strategy for InplaceSub
|
||||
Description: illegal strategy
|
||||
Expectation: raise RuntimeError
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy = ((1, 4, 2), (1, 2, 4))
|
||||
net = InplaceSubNet(indices_, strategy)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net, x_, input_v_)
|
|
@ -0,0 +1,55 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, context
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
from parallel.utils.utils import compile_net
|
||||
|
||||
# Random fp32 input of shape [32, 8, 8] shared by the L2Loss test cases below.
x_ = Tensor(np.random.normal(size=[32, 8, 8]).astype(np.float32))
|
||||
|
||||
|
||||
class Net(Cell):
    """Wraps P.L2Loss with an optional shard strategy for parallel compile tests."""

    def __init__(self, strategy=None):
        super(Net, self).__init__()
        # strategy=None lets auto parallel pick the sharding.
        self.l2_loss = P.L2Loss().shard(strategy)

    def construct(self, x):
        loss = self.l2_loss(x)
        return loss
|
||||
|
||||
|
||||
def test_l2_loss_auto_parallel():
    """
    Feature: test L2Loss auto parallel
    Description: strategy search is left to auto parallel
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    compile_net(Net(), x_)
|
||||
|
||||
|
||||
def test_l2_loss_model_parallel():
    """
    Feature: test L2Loss model parallel
    Description: single-input strategy splitting all three axes
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    in_strategy = ((2, 2, 2),)
    compile_net(Net(in_strategy), x_)
|
|
@ -0,0 +1,143 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mindspore import Tensor, context
|
||||
from mindspore.nn import Cell
|
||||
import mindspore.ops as ops
|
||||
|
||||
from parallel.utils.utils import compile_net
|
||||
|
||||
# Inputs shared by all Lerp test cases below.
input_start_ = Tensor(np.random.normal(size=[8, 8, 8]).astype(np.float32))
# 'end' has fewer dims than 'start' — presumably broadcast by Lerp; see each test's strategy.
input_end_ = Tensor(np.random.normal(size=[8]).astype(np.float32))
# 'weight' is exercised both as a tensor ...
input_weight_tensor_ = Tensor(np.random.normal(size=[8, 8]).astype(np.float32))
# ... and as a plain python float (two- vs three-tuple strategies).
input_weight_float_ = 0.5
|
||||
|
||||
|
||||
class Net(Cell):
    """Wraps ops.Lerp with an optional shard strategy for parallel compile tests."""

    def __init__(self, strategy=None):
        super(Net, self).__init__()
        # strategy=None lets auto parallel pick the sharding.
        self.lerp = ops.Lerp().shard(strategy)

    def construct(self, *inputs):
        # 'weight' may be a tensor or a float, so inputs are forwarded as-is.
        return self.lerp(*inputs)
|
||||
|
||||
|
||||
def test_lerp_auto_parallel_with_weight_tensor():
    """
    Feature: test Lerp auto parallel
    Description: auto parallel when 'weight' is tensor
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    compile_net(Net(), input_start_, input_end_, input_weight_tensor_)
|
||||
|
||||
|
||||
def test_lerp_auto_parallel_with_weight_float():
    """
    Feature: test Lerp auto parallel
    Description: auto parallel when 'weight' is float
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    compile_net(Net(), input_start_, input_end_, input_weight_float_)
|
||||
|
||||
|
||||
def test_lerp_model_parallel_with_weight_tensor():
    """
    Feature: test Lerp model parallel
    Description: model parallel when 'weight' is tensor, all 8 devices used
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    in_strategy = ((2, 2, 2), (2,), (2, 2))
    compile_net(Net(in_strategy), input_start_, input_end_, input_weight_tensor_)
|
||||
|
||||
|
||||
def test_lerp_model_parallel_with_weight_float():
    """
    Feature: test Lerp model parallel
    Description: model parallel when 'weight' is float (two-input strategy)
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    in_strategy = ((2, 2, 2), (2,))
    compile_net(Net(in_strategy), input_start_, input_end_, input_weight_float_)
|
||||
|
||||
|
||||
def test_lerp_model_parallel_repeated_cal_with_weight_tensor():
    """
    Feature: test Lerp model parallel with repeated calculation
    Description: 'weight' is tensor; strategy uses 4 of 8 devices
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    in_strategy = ((1, 2, 2), (2,), (2, 2))
    compile_net(Net(in_strategy), input_start_, input_end_, input_weight_tensor_)
|
||||
|
||||
|
||||
def test_lerp_model_parallel_repeated_cal_with_weight_float():
    """
    Feature: test Lerp model parallel with repeated calculation
    Description: 'weight' is float; strategy uses 4 of 8 devices
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    in_strategy = ((1, 2, 2), (2,))
    compile_net(Net(in_strategy), input_start_, input_end_, input_weight_float_)
|
||||
|
||||
|
||||
def test_lerp_data_parallel_with_weight_tensor():
    """
    Feature: test Lerp data parallel
    Description: 'weight' is tensor; only the batch axis of 'start' is split
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    in_strategy = ((8, 1, 1), (1,), (1, 1))
    compile_net(Net(in_strategy), input_start_, input_end_, input_weight_tensor_)
|
||||
|
||||
|
||||
def test_lerp_data_parallel_with_weight_float():
    """
    Feature: test Lerp data parallel
    Description: 'weight' is float; only the batch axis of 'start' is split
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    in_strategy = ((8, 1, 1), (1,))
    compile_net(Net(in_strategy), input_start_, input_end_, input_weight_float_)
|
||||
|
||||
|
||||
def test_lerp_strategy_error_with_weight_tensor():
    """
    Feature: test invalid strategy for Lerp
    Description: 'weight' is tensor; input strategies do not match on broadcast axes
    Expectation: raise RuntimeError
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    bad_strategy = ((4, 2, 1), (1,), (1, 2))
    net = Net(bad_strategy)
    with pytest.raises(RuntimeError):
        compile_net(net, input_start_, input_end_, input_weight_tensor_)
|
||||
|
||||
|
||||
def test_lerp_strategy_error_with_weight_float():
    """
    Feature: test invalid strategy for Lerp
    Description: 'weight' is float; input strategies do not match on broadcast axes
    Expectation: raise RuntimeError
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    bad_strategy = ((4, 1, 2), (1,))
    net = Net(bad_strategy)
    with pytest.raises(RuntimeError):
        compile_net(net, input_start_, input_end_, input_weight_float_)
|
|
@ -0,0 +1,162 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.common.api import _cell_graph_executor
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from tests.ut.python.ops.test_math_ops import VirtualLoss
|
||||
|
||||
# Gradient op that returns gradients w.r.t. all network inputs.
grad_all = C.GradOperation(get_all=True)
|
||||
|
||||
|
||||
class NetWithLoss(nn.Cell):
    """Attaches a VirtualLoss head to `network` so the compiled graph ends in a loss."""

    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y):
        out = self.network(x, y)
        return self.loss(out)
|
||||
|
||||
|
||||
class GradWrap(nn.Cell):
    """Wraps `network` so construct returns gradients w.r.t. both inputs."""

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x, y):
        grads = grad_all(self.network)(x, y)
        return grads
|
||||
|
||||
|
||||
def compile_net(net, x, y):
    """Compile `net` in train mode under the currently configured parallel context.

    Compilation alone is the assertion in these tests: strategy validation
    either succeeds or raises — the graph is never executed.
    """
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, y)
|
||||
|
||||
|
||||
def test_cumprod_semi():
    """
    Feature: CumProd operatorInfo in parallel.
    Description: MatMul->CumProd, CumProd strategy splits axis 0 which is also the cumulative axis.
    Expectation: CumProd does not support splitting the cumulative axis; compile raises RuntimeError.
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul1 = P.MatMul().shard(((16, 1), (1, 1)))
            # Splits axis 0 — the same axis CumProd accumulates over.
            self.cumprod = P.CumProd().shard(((16, 1),))

        def construct(self, x, y):
            out = self.matmul1(x, y)
            out = self.cumprod(out, 0)
            return out

    size = 16
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    # Strategy validation must reject splitting the cumulative axis.
    with pytest.raises(RuntimeError):
        compile_net(net, x, y)
|
||||
|
||||
|
||||
def test_cumprod_semi2():
    """
    Feature: CumProd operatorInfo in parallel.
    Description: MatMul->CumProd
    Expectation: Compile done without error.
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul1 = P.MatMul().shard(((16, 1), (1, 1)))
            # Cumulative axis is 0; only axis 1 is split, which is supported.
            self.cumprod = P.CumProd().shard(((1, 16),))

        def construct(self, x, y):
            return self.cumprod(self.matmul1(x, y), 0)

    device_count = 16
    context.set_auto_parallel_context(device_num=device_count, global_rank=0)
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    compile_net(net, x, y)
|
||||
|
||||
|
||||
def test_cumprod_semi3():
    """
    Feature: CumProd operatorInfo in parallel.
    Description: MatMul->CumProd
    Expectation: Compile done without error.
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul1 = P.MatMul().shard(((16, 1), (1, 1)))
            # Cumulative axis is 1; only axis 0 is split (repeated calculation).
            self.cumprod = P.CumProd().shard(((2, 1),))

        def construct(self, x, y):
            return self.cumprod(self.matmul1(x, y), 1)

    device_count = 16
    context.set_auto_parallel_context(device_num=device_count, global_rank=0)
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    compile_net(net, x, y)
|
||||
|
||||
|
||||
def test_cumprod_auto():
    """
    Feature: CumProd operatorInfo in parallel.
    Description: MatMul->CumProd
    Expectation: Compile done without error.
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul1 = P.MatMul().shard(((16, 1), (1, 1)))
            # No explicit strategy: auto parallel chooses CumProd's sharding.
            self.cumprod = P.CumProd()

        def construct(self, x, y):
            return self.cumprod(self.matmul1(x, y), -1)

    device_count = 16
    context.set_auto_parallel_context(device_num=device_count, global_rank=0)
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    compile_net(net, x, y)
|
|
@ -11,6 +11,9 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from mindspore import context
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.common.api import _cell_graph_executor
|
||||
|
||||
|
||||
|
@ -115,3 +118,11 @@ class ParallelValidator:
|
|||
if graph_name not in self._graph_info_dict.keys():
|
||||
raise ValueError("{} is not exist".format(graph_name))
|
||||
return self._graph_info_dict[graph_name]
|
||||
|
||||
|
||||
def compile_net(net: Cell, *inputs, auto_parallel_mode=False):
    """Compile `net` in train mode and return the compiled phase.

    The auto-parallel context is reset afterwards — also when compilation
    raises (e.g. in `pytest.raises(RuntimeError)` tests) — so one test's
    parallel configuration cannot leak into the next.

    Args:
        net (Cell): network to compile.
        *inputs: example inputs used to trace and compile the graph.
        auto_parallel_mode (bool): forwarded to the graph executor.

    Returns:
        The phase identifier produced by the graph executor.
    """
    net.set_auto_parallel()
    net.set_train()
    try:
        phase, _ = _cell_graph_executor.compile(net, *inputs, auto_parallel_mode=auto_parallel_mode)
    finally:
        # Previously only reached on success, leaking context after failures.
        context.reset_auto_parallel_context()
    return phase
|
||||
|
|
Loading…
Reference in New Issue