forked from mindspore-Ecosystem/mindspore
implement parallel Pack
This commit is contained in:
parent
9475f9a19a
commit
6066b16838
|
@ -199,6 +199,8 @@ class SoftmaxCost : public OperatorCost {
|
||||||
using SoftmaxCostPtr = std::shared_ptr<SoftmaxCost>;
|
using SoftmaxCostPtr = std::shared_ptr<SoftmaxCost>;
|
||||||
using TileCost = SoftmaxCost;
|
using TileCost = SoftmaxCost;
|
||||||
using TileCostPtr = std::shared_ptr<TileCost>;
|
using TileCostPtr = std::shared_ptr<TileCost>;
|
||||||
|
using PackCost = TileCost;
|
||||||
|
using PackCostPtr = std::shared_ptr<PackCost>;
|
||||||
using ConcatCost = TileCost;
|
using ConcatCost = TileCost;
|
||||||
using ConcatCostPtr = std::shared_ptr<ConcatCost>;
|
using ConcatCostPtr = std::shared_ptr<ConcatCost>;
|
||||||
using SplitCost = TileCost;
|
using SplitCost = TileCost;
|
||||||
|
|
|
@ -178,6 +178,7 @@ REGISTER(EmbeddingLookupInfo);
|
||||||
REGISTER(TileInfo);
|
REGISTER(TileInfo);
|
||||||
REGISTER(StridedSliceInfo);
|
REGISTER(StridedSliceInfo);
|
||||||
REGISTER(DropoutInfo);
|
REGISTER(DropoutInfo);
|
||||||
|
REGISTER(PackInfo);
|
||||||
REGISTER(ConcatInfo);
|
REGISTER(ConcatInfo);
|
||||||
REGISTER(SplitInfo);
|
REGISTER(SplitInfo);
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
|
|
|
@ -39,7 +39,6 @@ const std::set<std::string> BLACK_LIST = {TUPLE_GETITEM,
|
||||||
TILE_SHAPE,
|
TILE_SHAPE,
|
||||||
TUPLE_DIV,
|
TUPLE_DIV,
|
||||||
TUPLE_TO_ARRAY,
|
TUPLE_TO_ARRAY,
|
||||||
MAKE_LIST,
|
|
||||||
MAKE_DICT,
|
MAKE_DICT,
|
||||||
MAKE_SLICE,
|
MAKE_SLICE,
|
||||||
MAKE_RECORD,
|
MAKE_RECORD,
|
||||||
|
|
|
@ -41,5 +41,6 @@
|
||||||
#include "frontend/parallel/ops_info/strided_slice_info.h"
|
#include "frontend/parallel/ops_info/strided_slice_info.h"
|
||||||
#include "frontend/parallel/ops_info/concat_info.h"
|
#include "frontend/parallel/ops_info/concat_info.h"
|
||||||
#include "frontend/parallel/ops_info/split_info.h"
|
#include "frontend/parallel/ops_info/split_info.h"
|
||||||
|
#include "frontend/parallel/ops_info/pack_info.h"
|
||||||
|
|
||||||
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_
|
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_
|
||||||
|
|
|
@ -0,0 +1,253 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "frontend/parallel/ops_info/pack_info.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "frontend/parallel/device_matrix.h"
|
||||||
|
#include "frontend/parallel/strategy.h"
|
||||||
|
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
|
||||||
|
#include "pipeline/jit/resource.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace parallel {
|
||||||
|
Status PackInfo::GetAttrs() {
|
||||||
|
int axis = 0;
|
||||||
|
auto axis_iter = attrs_.find(AXIS);
|
||||||
|
if (axis_iter != attrs_.end()) {
|
||||||
|
MS_EXCEPTION_IF_NULL(axis_iter->second);
|
||||||
|
if (axis_iter->second->isa<Int32Imm>()) {
|
||||||
|
axis = axis_iter->second->cast<Int32ImmPtr>()->value();
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << name_ << ": The value of axis is not int";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Can not find the axis attr";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (inputs_shape_.empty()) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
int dim = SizeToInt(inputs_shape_[0].size());
|
||||||
|
|
||||||
|
if (axis < 0) {
|
||||||
|
axis = axis + dim;
|
||||||
|
}
|
||||||
|
axis_ = SizeToInt(axis);
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status PackInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||||
|
MS_EXCEPTION_IF_NULL(strategy);
|
||||||
|
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Invalid strategy";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Dimensions> stra = strategy->GetInputDim();
|
||||||
|
for (size_t i = 0; i < stra.size(); ++i) {
|
||||||
|
auto strategy_ele = stra[i];
|
||||||
|
if (axis_ > strategy_ele.size()) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": The axis is out of range, the axis is " << axis_;
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t j = 0; j < strategy_ele.size(); ++j) {
|
||||||
|
if (strategy_ele[j] != stra[0][j]) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": The strategy of each input tensor must be equal";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status PackInfo::InferDevMatrixShape() {
|
||||||
|
MS_EXCEPTION_IF_NULL(strategy_);
|
||||||
|
std::vector<Dimensions> stra = strategy_->GetInputDim();
|
||||||
|
if (stra.empty()) {
|
||||||
|
MS_LOG(ERROR) << name_ << "The strategy is empty";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
dev_matrix_shape_ = stra[0];
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status PackInfo::InferTensorMap() {
|
||||||
|
TensorMap in_tensor_map;
|
||||||
|
TensorMap out_tensor_map;
|
||||||
|
|
||||||
|
if (inputs_shape_.empty()) {
|
||||||
|
MS_LOG(ERROR) << name_ << "The inputs shape is empty";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t size = SizeToInt(inputs_shape_[0].size());
|
||||||
|
for (int i = 0; i < size; ++i) {
|
||||||
|
in_tensor_map.push_back(size - i - 1);
|
||||||
|
out_tensor_map.push_back(size - i - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
|
||||||
|
inputs_tensor_map_.push_back(in_tensor_map);
|
||||||
|
}
|
||||||
|
|
||||||
|
out_tensor_map.insert(out_tensor_map.begin() + axis_, MAP_NONE);
|
||||||
|
outputs_tensor_map_.push_back(out_tensor_map);
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status PackInfo::InferMirrorOps() {
|
||||||
|
mirror_ops_.clear();
|
||||||
|
if (inputs_tensor_map_.empty()) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": The inputs tensor map is empty";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
Shape input_tensor_map = inputs_tensor_map_[0];
|
||||||
|
std::vector<Group> group;
|
||||||
|
if (CreateGroupByTensorMap(input_tensor_map, &group) != SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Create group for input failed.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (group.empty()) {
|
||||||
|
MS_LOG(INFO) << name_ << ": The mirror group is empty.";
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
OperatorVector input_op;
|
||||||
|
input_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
|
||||||
|
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
|
||||||
|
mirror_ops_.push_back(input_op);
|
||||||
|
}
|
||||||
|
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status PackInfo::InferTensorInfo() {
|
||||||
|
if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Invalid args";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorLayout input_layout, output_layout;
|
||||||
|
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
|
||||||
|
// infer tensor layout
|
||||||
|
if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[i], inputs_shape_[i]) != SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
TensorInfo input_tensor_info(input_layout);
|
||||||
|
inputs_tensor_info_.push_back(input_tensor_info);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
TensorInfo output_tensor_info(output_layout);
|
||||||
|
outputs_tensor_info_.push_back(output_tensor_info);
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
void PackInfo::ReComputeBatchSplitFlagList() {
|
||||||
|
for (size_t i = 0; i < inputs_shape_.size(); i++) {
|
||||||
|
split_flag_list_[i] = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Status PackInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
||||||
|
|
||||||
|
Status PackInfo::GenerateStrategies(int32_t stage_id) {
|
||||||
|
if (InferAttrs() != SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Infer attrs failed";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
if (inputs_shape_.empty()) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
Shape input_split;
|
||||||
|
for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
|
||||||
|
input_split.push_back(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// to generate the first input's strategy
|
||||||
|
Shapes splittable_input = {input_split};
|
||||||
|
Shapes tmp_inputs_shape = {inputs_shape_[0]};
|
||||||
|
|
||||||
|
std::vector<StrategyPtr> sp_vector;
|
||||||
|
if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_input, &sp_vector) != SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Generate strategies failed";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
// the others strategies are equal to the first input's strategy
|
||||||
|
for (auto &sp : sp_vector) {
|
||||||
|
if ((sp == nullptr) || sp->GetInputDim().empty()) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": The strategy is null or empty";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
Strategys tmp_strategy;
|
||||||
|
Dimensions first_input_strategy = sp->GetInputDim()[0];
|
||||||
|
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
|
||||||
|
tmp_strategy.push_back(first_input_strategy);
|
||||||
|
}
|
||||||
|
sp->ResetInputs(tmp_strategy);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t success = 0;
|
||||||
|
for (auto &sp : sp_vector) {
|
||||||
|
PrintStrategy(sp);
|
||||||
|
if (SetCostUnderStrategy(sp) == SUCCESS) {
|
||||||
|
success++;
|
||||||
|
MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy.";
|
||||||
|
PrintStrategy(sp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status PackInfo::Init(const StrategyPtr &strategy) {
|
||||||
|
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Init failed.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << name_ << ": Init success.";
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status PackInfo::InitForCostModel(const StrategyPtr &strategy) {
|
||||||
|
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_LOG(INFO) << name_ << ": Init for cost model success.";
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
} // namespace parallel
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,62 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_PACK_INFO_H_
|
||||||
|
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_PACK_INFO_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "ir/value.h"
|
||||||
|
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
|
||||||
|
#include "frontend/parallel/ops_info/operator_info.h"
|
||||||
|
#include "frontend/parallel/strategy.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace parallel {
|
||||||
|
class PackInfo : public OperatorInfo {
|
||||||
|
public:
|
||||||
|
PackInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||||
|
const PrimitiveAttrs &attrs)
|
||||||
|
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<PackCost>(false)) {}
|
||||||
|
~PackInfo() override = default;
|
||||||
|
|
||||||
|
Status Init(const StrategyPtr &strategy) override;
|
||||||
|
Status InitForCostModel(const StrategyPtr &strategy) override;
|
||||||
|
Status GenerateStrategies(int32_t) override;
|
||||||
|
Status SetCostUnderStrategy(const StrategyPtr &) override;
|
||||||
|
void ReComputeBatchSplitFlagList() override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
Status GetAttrs() override;
|
||||||
|
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||||
|
Status InferMirrorOps() override;
|
||||||
|
Status InferForwardCommunication() override { return SUCCESS; }
|
||||||
|
Status InferTensorInfo() override;
|
||||||
|
Status InferDevMatrixShape() override;
|
||||||
|
Status InferTensorMap() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
size_t axis_ = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
using PackInfoPtr = std::shared_ptr<PackInfo>;
|
||||||
|
} // namespace parallel
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_PACK_INFO_H_
|
|
@ -116,7 +116,8 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
|
||||||
std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
|
std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
|
||||||
std::vector<bool> is_parameter;
|
std::vector<bool> is_parameter;
|
||||||
std::vector<AnfNodePtr> node_inputs{node->inputs()};
|
std::vector<AnfNodePtr> node_inputs{node->inputs()};
|
||||||
if ((node_inputs.size() == 2) && AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE)) {
|
if ((node_inputs.size() == 2) &&
|
||||||
|
(AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) {
|
||||||
node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
|
node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
|
||||||
}
|
}
|
||||||
for (size_t i = 1; i < node_inputs.size(); ++i) {
|
for (size_t i = 1; i < node_inputs.size(); ++i) {
|
||||||
|
@ -193,7 +194,8 @@ std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
|
||||||
std::vector<size_t> inputs_type_len;
|
std::vector<size_t> inputs_type_len;
|
||||||
std::vector<AnfNodePtr> node_inputs{node->inputs()};
|
std::vector<AnfNodePtr> node_inputs{node->inputs()};
|
||||||
|
|
||||||
if ((node_inputs.size() == 2) && AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE)) {
|
if ((node_inputs.size() == 2) &&
|
||||||
|
(AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) {
|
||||||
node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
|
node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -259,7 +261,7 @@ bool IsSplittableOperator(const std::string &op_name) {
|
||||||
{MATMUL, TRANSPOSE, GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU,
|
{MATMUL, TRANSPOSE, GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU,
|
||||||
FLOORDIV, L2_NORMALIZE, TENSOR_ADD, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK,
|
FLOORDIV, L2_NORMALIZE, TENSOR_ADD, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK,
|
||||||
REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
|
REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
|
||||||
MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP,
|
MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, PACK,
|
||||||
LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT,
|
LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT,
|
||||||
STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT,
|
STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT,
|
||||||
SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS,
|
SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS,
|
||||||
|
@ -281,7 +283,7 @@ bool IsAutoParallelCareNode(const CNodePtr &cnode) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
bool bool_result = IsParallelCareNode(cnode) && !IsSplittableOperator(prim->name());
|
bool bool_result = IsParallelCareNode(cnode) && !IsSplittableOperator(prim->name());
|
||||||
if (bool_result && (prim->name() != MAKE_TUPLE)) {
|
if (bool_result && (prim->name() != MAKE_TUPLE) && (prim->name() != MAKE_LIST)) {
|
||||||
MS_LOG(EXCEPTION) << "Should implementing OperatorInfo for: " << prim->name();
|
MS_LOG(EXCEPTION) << "Should implementing OperatorInfo for: " << prim->name();
|
||||||
} else if (prim->name() == CAST) {
|
} else if (prim->name() == CAST) {
|
||||||
if (cnode->fullname_with_scope().find(OPTIMIZER_SUB_STRING) != std::string::npos) {
|
if (cnode->fullname_with_scope().find(OPTIMIZER_SUB_STRING) != std::string::npos) {
|
||||||
|
|
|
@ -450,7 +450,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
|
||||||
AnfNodeIndexSet node_set = manager->node_users()[node];
|
AnfNodeIndexSet node_set = manager->node_users()[node];
|
||||||
CNodePtr insert_node_new;
|
CNodePtr insert_node_new;
|
||||||
|
|
||||||
if (AnfNodeIsPrimitive(node, MAKE_TUPLE)) {
|
if (AnfNodeIsPrimitive(node, MAKE_TUPLE) || AnfNodeIsPrimitive(node, MAKE_LIST)) {
|
||||||
MS_LOG(INFO) << "No need to insert redistribution op betweend make_tuple node and the next node";
|
MS_LOG(INFO) << "No need to insert redistribution op betweend make_tuple node and the next node";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -851,7 +851,8 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
|
||||||
FuncGraphManagerPtr manager = func_graph->manager();
|
FuncGraphManagerPtr manager = func_graph->manager();
|
||||||
MS_EXCEPTION_IF_NULL(manager);
|
MS_EXCEPTION_IF_NULL(manager);
|
||||||
|
|
||||||
if ((node->inputs().size() == 2) && AnfNodeIsPrimitive(node->input(1), MAKE_TUPLE)) {
|
if ((node->inputs().size() == 2) &&
|
||||||
|
(AnfNodeIsPrimitive(node->input(1), MAKE_TUPLE) || AnfNodeIsPrimitive(node->input(1), MAKE_LIST))) {
|
||||||
MS_LOG(INFO) << "The mirror for " << GetPrimName(node) << " has handle by make_tuple node";
|
MS_LOG(INFO) << "The mirror for " << GetPrimName(node) << " has handle by make_tuple node";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -1055,7 +1056,7 @@ Shapes GetNodeShape(const AnfNodePtr &node) {
|
||||||
MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " shape_ptr is nullptr, full name is "
|
MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " shape_ptr is nullptr, full name is "
|
||||||
<< node->fullname_with_scope();
|
<< node->fullname_with_scope();
|
||||||
}
|
}
|
||||||
auto tuple_shape_ptr = dyn_cast<abstract::TupleShape>(base_shape_ptr);
|
auto tuple_shape_ptr = dyn_cast<abstract::SequeueShape>(base_shape_ptr);
|
||||||
if (tuple_shape_ptr != nullptr) {
|
if (tuple_shape_ptr != nullptr) {
|
||||||
auto tuple_shape = tuple_shape_ptr->shape();
|
auto tuple_shape = tuple_shape_ptr->shape();
|
||||||
for (auto &shape : tuple_shape) {
|
for (auto &shape : tuple_shape) {
|
||||||
|
@ -1436,7 +1437,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
|
||||||
SetVirtualDatasetStrategy(cnode);
|
SetVirtualDatasetStrategy(cnode);
|
||||||
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||||
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
|
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
|
||||||
if (prim->name() == MAKE_TUPLE) {
|
if (prim->name() == MAKE_TUPLE || prim->name() == MAKE_LIST) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto attrs = prim->attrs();
|
auto attrs = prim->attrs();
|
||||||
|
@ -2459,9 +2460,9 @@ Status ParallelInit() {
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
void HandleForwardMakeTuple(const std::vector<AnfNodePtr> &all_nodes) {
|
void HandleForwardMakeTupleAndMakeList(const std::vector<AnfNodePtr> &all_nodes) {
|
||||||
for (auto &node : all_nodes) {
|
for (auto &node : all_nodes) {
|
||||||
if (!AnfNodeIsPrimitive(node, MAKE_TUPLE)) {
|
if (!AnfNodeIsPrimitive(node, MAKE_TUPLE) && !AnfNodeIsPrimitive(node, MAKE_LIST)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2473,25 +2474,28 @@ void HandleForwardMakeTuple(const std::vector<AnfNodePtr> &all_nodes) {
|
||||||
|
|
||||||
FuncGraphManagerPtr manager = cnode->func_graph()->manager();
|
FuncGraphManagerPtr manager = cnode->func_graph()->manager();
|
||||||
MS_EXCEPTION_IF_NULL(manager);
|
MS_EXCEPTION_IF_NULL(manager);
|
||||||
auto make_tuple_user = manager->node_users()[cnode];
|
std::string op_type = AnfNodeIsPrimitive(node, MAKE_TUPLE) ? MAKE_TUPLE : MAKE_LIST;
|
||||||
if (make_tuple_user.size() != 1) {
|
|
||||||
MS_LOG(EXCEPTION) << "Now the make_tuple's user must be 1, but got " << make_tuple_user.size();
|
|
||||||
}
|
|
||||||
CNodePtr make_tuple_next_cnode = make_tuple_user.pop().first->cast<CNodePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(make_tuple_next_cnode);
|
|
||||||
|
|
||||||
std::string make_tuple_user_prim_name = GetPrimName(make_tuple_next_cnode);
|
auto make_tuple_list_user = manager->node_users()[cnode];
|
||||||
if (!IsParallelCareNode(make_tuple_next_cnode)) {
|
if (make_tuple_list_user.size() != 1) {
|
||||||
MS_LOG(INFO) << "The make_tuple's user is " << make_tuple_user_prim_name << ", no need to set operator info";
|
MS_LOG(EXCEPTION) << "Now the " << op_type << "'s user must be 1, but got " << make_tuple_list_user.size();
|
||||||
|
}
|
||||||
|
CNodePtr make_tuple_list_next_cnode = make_tuple_list_user.pop().first->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(make_tuple_list_next_cnode);
|
||||||
|
|
||||||
|
std::string make_tuple__list_user_prim_name = GetPrimName(make_tuple_list_next_cnode);
|
||||||
|
if (!IsParallelCareNode(make_tuple_list_next_cnode)) {
|
||||||
|
MS_LOG(INFO) << "The " << op_type << "'s user is " << make_tuple__list_user_prim_name
|
||||||
|
<< ", no need to set operator info";
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (make_tuple_next_cnode->inputs().size() != 2) {
|
if (make_tuple_list_next_cnode->inputs().size() != 2) {
|
||||||
MS_LOG(EXCEPTION) << "Now the make_tuple's user only support 1 input, but got "
|
MS_LOG(EXCEPTION) << "Now the " << op_type << "'s user only support 1 input, but got "
|
||||||
<< make_tuple_next_cnode->inputs().size() - 1;
|
<< make_tuple_list_next_cnode->inputs().size() - 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(INFO) << "Set the make_tuple's operator info, and the op name is " << make_tuple_user_prim_name;
|
MS_LOG(INFO) << "Set the " << op_type << "'s operator info, and the op name is " << make_tuple__list_user_prim_name;
|
||||||
OperatorInfoPtr op_info = GetDistributeOperator(make_tuple_next_cnode);
|
OperatorInfoPtr op_info = GetDistributeOperator(make_tuple_list_next_cnode);
|
||||||
MS_EXCEPTION_IF_NULL(op_info);
|
MS_EXCEPTION_IF_NULL(op_info);
|
||||||
cnode->set_user_data<OperatorInfo>(op_info);
|
cnode->set_user_data<OperatorInfo>(op_info);
|
||||||
}
|
}
|
||||||
|
@ -2695,7 +2699,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
|
||||||
ReshapeInit(all_nodes);
|
ReshapeInit(all_nodes);
|
||||||
}
|
}
|
||||||
|
|
||||||
HandleForwardMakeTuple(all_nodes);
|
HandleForwardMakeTupleAndMakeList(all_nodes);
|
||||||
|
|
||||||
// if the input or parameter has multiple users, check whether its split strategies are consistent.
|
// if the input or parameter has multiple users, check whether its split strategies are consistent.
|
||||||
CheckParameterSplit(all_nodes);
|
CheckParameterSplit(all_nodes);
|
||||||
|
|
|
@ -348,6 +348,16 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py
|
||||||
}
|
}
|
||||||
auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
|
auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
|
||||||
return tuple;
|
return tuple;
|
||||||
|
} else if (py::isinstance<py::list>(shape_obj) && py::isinstance<py::list>(type_obj)) {
|
||||||
|
py::list shape_list = shape_obj.cast<py::list>();
|
||||||
|
py::list typeid_list = type_obj.cast<py::list>();
|
||||||
|
AbstractBasePtrList ptr_list;
|
||||||
|
for (size_t it = 0; it < shape_list.size(); ++it) {
|
||||||
|
auto tensor_it = PyListDtype2AbstractTensor(shape_list[it], typeid_list[it]);
|
||||||
|
ptr_list.push_back(tensor_it);
|
||||||
|
}
|
||||||
|
auto list = std::make_shared<abstract::AbstractList>(ptr_list);
|
||||||
|
return list;
|
||||||
} else if (shape_obj.is_none() && type_obj.is_none()) {
|
} else if (shape_obj.is_none() && type_obj.is_none()) {
|
||||||
// AbstractNone indicates there is no output for this CNode node.
|
// AbstractNone indicates there is no output for this CNode node.
|
||||||
auto abstract_none = std::make_shared<abstract::AbstractNone>();
|
auto abstract_none = std::make_shared<abstract::AbstractNone>();
|
||||||
|
|
|
@ -228,11 +228,19 @@ def get_bprop_virtual_div_operator(self):
|
||||||
dx = op(dout, cast(F.scalar_to_array(divisor), dtype(dout)))
|
dx = op(dout, cast(F.scalar_to_array(divisor), dtype(dout)))
|
||||||
return (dx,)
|
return (dx,)
|
||||||
|
|
||||||
dx = ()
|
if F.issubclass_(F.typeof(dout), mstype.tuple_):
|
||||||
input_nums = F.tuple_len(dout)
|
dx = ()
|
||||||
|
input_nums = F.tuple_len(dout)
|
||||||
|
for i in range(input_nums):
|
||||||
|
ele_grad = op(dout[i], cast(F.scalar_to_array(divisor), dtype(dout[i])))
|
||||||
|
dx = dx + (ele_grad,)
|
||||||
|
return (dx,)
|
||||||
|
|
||||||
|
dx = []
|
||||||
|
input_nums = F.list_len(dout)
|
||||||
for i in range(input_nums):
|
for i in range(input_nums):
|
||||||
ele_grad = op(dout[i], cast(F.scalar_to_array(divisor), dtype(dout[i])))
|
ele_grad = op(dout[i], cast(F.scalar_to_array(divisor), dtype(dout[i])))
|
||||||
dx = dx + (ele_grad,)
|
dx.append(ele_grad)
|
||||||
return (dx,)
|
return (dx,)
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
|
@ -92,6 +92,7 @@ dict_getitem = Primitive('dict_getitem')
|
||||||
dict_setitem = Primitive('dict_setitem')
|
dict_setitem = Primitive('dict_setitem')
|
||||||
tuple_div = Primitive("tuple_div")
|
tuple_div = Primitive("tuple_div")
|
||||||
tuple_len = Primitive("tuple_len")
|
tuple_len = Primitive("tuple_len")
|
||||||
|
list_len = Primitive("list_len")
|
||||||
tuple_reversed = Primitive("tuple_reversed")
|
tuple_reversed = Primitive("tuple_reversed")
|
||||||
make_range = Primitive("make_range")
|
make_range = Primitive("make_range")
|
||||||
make_tuple = Primitive('make_tuple')
|
make_tuple = Primitive('make_tuple')
|
||||||
|
|
|
@ -0,0 +1,188 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import mindspore as ms
|
||||||
|
import mindspore.context as context
|
||||||
|
from mindspore import Tensor, Parameter
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore.common.api import _executor
|
||||||
|
from mindspore.nn import TrainOneStepCell, Momentum
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self, weight1, weight2, axis=0, strategy1=None, strategy2=None, is_parameter=True):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.pack = P.Pack(axis=axis).shard(strategy1)
|
||||||
|
self.mul = P.Mul().shard(strategy2)
|
||||||
|
if is_parameter:
|
||||||
|
self.weight1 = Parameter(weight1, "w1")
|
||||||
|
else:
|
||||||
|
self.weight1 = weight1
|
||||||
|
self.weight2 = Parameter(weight2, "w2")
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
out = self.pack([self.weight1, self.weight2])
|
||||||
|
out = self.mul(x, out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Net1(nn.Cell):
|
||||||
|
def __init__(self, weight1, weight2, axis=0, strategy1=None, strategy2=None):
|
||||||
|
super(Net1, self).__init__()
|
||||||
|
self.pack = P.Pack(axis=axis).shard(strategy1)
|
||||||
|
self.mul = P.Mul().shard(strategy2)
|
||||||
|
self.weight1 = Parameter(weight1, "w1")
|
||||||
|
self.weight2 = Parameter(weight2, "w2")
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
out = self.mul(x, self.weight1)
|
||||||
|
out = self.pack([out, self.weight2])
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Net2(nn.Cell):
|
||||||
|
def __init__(self, weight1, weight2, weight3, axis=0, strategy1=None, strategy2=None, is_parameter=True):
|
||||||
|
super(Net2, self).__init__()
|
||||||
|
self.pack = P.Pack(axis=axis).shard(strategy1)
|
||||||
|
self.mul = P.Mul().shard(strategy2)
|
||||||
|
if is_parameter:
|
||||||
|
self.weight1 = Parameter(weight1, "w1")
|
||||||
|
else:
|
||||||
|
self.weight1 = weight1
|
||||||
|
self.weight2 = Parameter(weight2, "w2")
|
||||||
|
self.weight3 = Parameter(weight2, "w3")
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
out = self.pack([self.weight1, self.weight2, self.weight3])
|
||||||
|
out = self.mul(x, out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
_w1 = Tensor(np.ones([48, 64]), dtype=ms.float32)
|
||||||
|
_w2 = Tensor(np.ones([48, 64]), dtype=ms.float32)
|
||||||
|
_w3 = Tensor(np.ones([48, 64]), dtype=ms.float32)
|
||||||
|
_x = Tensor(np.ones([2, 48, 64]), dtype=ms.float32)
|
||||||
|
_x1 = Tensor(np.ones([48, 64]), dtype=ms.float32)
|
||||||
|
_x2 = Tensor(np.ones([3, 48, 64]), dtype=ms.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def compile_net(net):
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||||
|
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||||
|
train_net = TrainOneStepCell(net, optimizer)
|
||||||
|
train_net.set_auto_parallel()
|
||||||
|
_executor.compile(train_net, _x)
|
||||||
|
context.reset_auto_parallel_context()
|
||||||
|
|
||||||
|
|
||||||
|
def compile_net1(net):
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||||
|
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||||
|
train_net = TrainOneStepCell(net, optimizer)
|
||||||
|
train_net.set_auto_parallel()
|
||||||
|
_executor.compile(train_net, _x1)
|
||||||
|
context.reset_auto_parallel_context()
|
||||||
|
|
||||||
|
|
||||||
|
def compile_net2(net):
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||||
|
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||||
|
train_net = TrainOneStepCell(net, optimizer)
|
||||||
|
train_net.set_auto_parallel()
|
||||||
|
_executor.compile(train_net, _x2)
|
||||||
|
context.reset_auto_parallel_context()
|
||||||
|
|
||||||
|
|
||||||
|
def test_pack_parameter():
|
||||||
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||||
|
strategy1 = ((4, 2), (4, 2))
|
||||||
|
strategy2 = ((1, 4, 2), (1, 4, 2))
|
||||||
|
net = Net(_w1, _w2, 0, strategy1, strategy2)
|
||||||
|
compile_net(net)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pack_parameter_no_full_split():
|
||||||
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||||
|
strategy1 = ((2, 2), (2, 2))
|
||||||
|
strategy2 = ((1, 4, 2), (1, 4, 2))
|
||||||
|
net = Net(_w1, _w2, 0, strategy1, strategy2)
|
||||||
|
compile_net(net)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pack_tensor_and_parameter():
|
||||||
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||||
|
strategy1 = ((4, 2), (4, 2))
|
||||||
|
strategy2 = ((1, 4, 2), (1, 4, 2))
|
||||||
|
net = Net(_w1, _w2, 0, strategy1, strategy2, False)
|
||||||
|
compile_net(net)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pack_output():
|
||||||
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||||
|
strategy1 = ((4, 2), (4, 2))
|
||||||
|
strategy2 = ((4, 2), (4, 2))
|
||||||
|
net = Net1(_w1, _w2, 0, strategy1, strategy2)
|
||||||
|
compile_net1(net)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pack_output_axis1():
|
||||||
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||||
|
strategy1 = ((4, 2), (4, 2))
|
||||||
|
strategy2 = ((4, 2), (4, 2))
|
||||||
|
net = Net1(_w1, _w2, 1, strategy1, strategy2)
|
||||||
|
compile_net1(net)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pack_output_no_full_split():
|
||||||
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||||
|
strategy1 = ((2, 2), (2, 2))
|
||||||
|
strategy2 = ((4, 2), (4, 2))
|
||||||
|
net = Net1(_w1, _w2, 0, strategy1, strategy2)
|
||||||
|
compile_net1(net)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pack_no_strategy():
|
||||||
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||||
|
strategy1 = None
|
||||||
|
strategy2 = ((4, 2), (4, 2))
|
||||||
|
net = Net1(_w1, _w2, 0, strategy1, strategy2)
|
||||||
|
compile_net1(net)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pack_no_strategy_axis1():
|
||||||
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||||
|
strategy1 = None
|
||||||
|
strategy2 = ((4, 2), (4, 2))
|
||||||
|
net = Net1(_w1, _w2, 1, strategy1, strategy2)
|
||||||
|
compile_net1(net)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pack_auto_parallel():
|
||||||
|
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||||
|
net = Net1(_w1, _w2, 0)
|
||||||
|
compile_net1(net)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pack_auto_parallel_axis1():
|
||||||
|
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||||
|
net = Net1(_w1, _w2, 1)
|
||||||
|
compile_net1(net)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pack_auto_parallel_3_tensor():
|
||||||
|
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||||
|
net = Net2(_w1, _w2, _w3)
|
||||||
|
compile_net2(net)
|
Loading…
Reference in New Issue