add tile op
This commit is contained in:
parent
669a8969c7
commit
6a6e2bd271
|
@ -198,8 +198,8 @@ double ActivationCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs
|
|||
// this operator uses
|
||||
double ActivationCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
|
||||
int32_t) const {
|
||||
TensorInfo input0_info = inputs[0];
|
||||
Shape input0_slice_shape = input0_info.slice_shape();
|
||||
TensorInfo input0 = inputs[0];
|
||||
Shape input0_slice_shape = input0.slice_shape();
|
||||
return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
|
||||
}
|
||||
|
||||
|
@ -240,12 +240,16 @@ double SoftmaxCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, c
|
|||
|
||||
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
|
||||
// this operator uses
|
||||
double SoftmaxCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
|
||||
double SoftmaxCost::GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &outputs,
|
||||
int32_t) const {
|
||||
// In the forward phase, the computation cost = slice(A)
|
||||
TensorInfo input0 = inputs[0];
|
||||
Shape input0_slice_shape = input0.slice_shape();
|
||||
return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
|
||||
if (outputs.empty() || outputs_type_lengths_.empty()) {
|
||||
MS_LOG(EXCEPTION) << "The outputs or outputs_type_length is empty";
|
||||
}
|
||||
|
||||
// use output for Tile operator
|
||||
TensorInfo output_info = outputs[0];
|
||||
Shape output_slice_shape = output_info.slice_shape();
|
||||
return ListProduct(output_slice_shape) * static_cast<double>(outputs_type_lengths_[0]);
|
||||
}
|
||||
|
||||
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
|
||||
|
|
|
@ -195,6 +195,8 @@ class SoftmaxCost : public OperatorCost {
|
|||
int32_t) const override;
|
||||
};
|
||||
using SoftmaxCostPtr = std::shared_ptr<SoftmaxCost>;
|
||||
using TileCost = SoftmaxCost;
|
||||
using TileCostPtr = std::shared_ptr<TileCost>;
|
||||
|
||||
class TmpIdentityCost : public OperatorCost {
|
||||
public:
|
||||
|
|
|
@ -133,6 +133,7 @@ REGISTER(SigmoidCrossEntropyWithLogitsInfo);
|
|||
REGISTER(SquareInfo);
|
||||
REGISTER(GatherV2PInfo);
|
||||
REGISTER(EmbeddingLookupInfo);
|
||||
REGISTER(TileInfo);
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -37,5 +37,6 @@
|
|||
#include "frontend/parallel/ops_info/transpose_info.h"
|
||||
#include "frontend/parallel/ops_info/virtual_dataset_info.h"
|
||||
#include "frontend/parallel/ops_info/gather_v2_p_info.h"
|
||||
#include "frontend/parallel/ops_info/tile_info.h"
|
||||
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_
|
||||
|
|
|
@ -182,6 +182,7 @@ constexpr char RELU[] = "ReLU";
|
|||
constexpr char ONEHOT[] = "OneHot";
|
||||
constexpr char DROPOUT_DO_MASK[] = "DropoutDoMask";
|
||||
constexpr char DROPOUT_GEN_MASK[] = "DropoutGenMask";
|
||||
constexpr char TILE[] = "Tile";
|
||||
constexpr char REDUCE_MAX[] = "ReduceMax";
|
||||
constexpr char REDUCE_MIN[] = "ReduceMin";
|
||||
constexpr char REDUCE_SUM[] = "ReduceSum";
|
||||
|
|
|
@ -0,0 +1,252 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "frontend/parallel/ops_info/tile_info.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "frontend/parallel/device_matrix.h"
|
||||
#include "frontend/parallel/strategy.h"
|
||||
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
|
||||
#include "pipeline/jit/resource.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
// get the multiples
|
||||
Status TileInfo::GetAttrs() {
|
||||
if (input_value_.size() < 2) {
|
||||
MS_LOG(ERROR) << name_ << ": The size of input value is smaller than 2.";
|
||||
return FAILED;
|
||||
}
|
||||
if (input_value_[1] == nullptr) {
|
||||
MS_LOG(ERROR) << name_ << ": The multiples is null.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
std::vector<ValuePtr> elements;
|
||||
ValueTuplePtr multiples = input_value_[1]->cast<ValueTuplePtr>();
|
||||
if (multiples == nullptr) {
|
||||
MS_LOG(ERROR) << name_ << ": Input_value_[1] must be ValueTuplePtr.";
|
||||
return FAILED;
|
||||
}
|
||||
elements = multiples->value();
|
||||
if (elements.size() != outputs_shape_[0].size()) {
|
||||
MS_LOG(ERROR) << name_ << ": Elements size must equal to outputs shape[0] size.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
for (auto &element : elements) {
|
||||
MS_EXCEPTION_IF_NULL(element);
|
||||
if (element->isa<Int32Imm>()) {
|
||||
int32_t axis = element->cast<Int32ImmPtr>()->value();
|
||||
full_multiples_.push_back(axis);
|
||||
} else {
|
||||
MS_LOG(ERROR) << name_ << ": The value of axis must be int32.";
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status TileInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
Shapes multiples = {full_multiples_};
|
||||
if (CheckStrategyValue(strategy, multiples, is_auto_parallel_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status TileInfo::InferDevMatrixShape() {
|
||||
MS_EXCEPTION_IF_NULL(strategy_);
|
||||
std::vector<Dimensions> stra = strategy_->GetInputDim();
|
||||
if (stra.empty()) {
|
||||
MS_LOG(ERROR) << name_ << "The strategy is empty";
|
||||
return FAILED;
|
||||
}
|
||||
if (full_multiples_.size() != stra[0].size()) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
dev_matrix_shape_ = stra[0];
|
||||
|
||||
slice_multiples_ = full_multiples_;
|
||||
for (size_t i = 0; i < full_multiples_.size(); ++i) {
|
||||
slice_multiples_[i] = slice_multiples_[i] / dev_matrix_shape_[i];
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status TileInfo::InferTensorMap() {
|
||||
TensorMap input_tensor_map;
|
||||
TensorMap output_tensor_map;
|
||||
if (inputs_shape_.empty() || outputs_shape_.empty()) {
|
||||
MS_LOG(ERROR) << name_ << "The inputs or outputs' shape is empty";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// the input tensor cannot be split
|
||||
for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
|
||||
input_tensor_map.push_back(MAP_NONE);
|
||||
}
|
||||
|
||||
// cannot use dev_matrix_shape_ replace outputs_shape_[0], because it may not be fully split in all devices.
|
||||
int32_t size = SizeToInt(outputs_shape_[0].size());
|
||||
for (int i = 0; i < size; ++i) {
|
||||
output_tensor_map.push_back(size - i - 1);
|
||||
}
|
||||
|
||||
inputs_tensor_map_.push_back(input_tensor_map);
|
||||
outputs_tensor_map_.push_back(output_tensor_map);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status TileInfo::InferMirrorOps() {
|
||||
mirror_ops_.clear();
|
||||
Shape input_tensor_map = inputs_tensor_map_[0];
|
||||
std::vector<Group> group;
|
||||
if (CreateGroupByTensorMap(input_tensor_map, &group) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Create group for input failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (group.empty()) {
|
||||
MS_LOG(INFO) << name_ << ": The mirror group is empty.";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
OperatorVector input_op, multiples_op;
|
||||
input_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
|
||||
mirror_ops_.push_back(input_op);
|
||||
mirror_ops_.push_back(multiples_op);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status TileInfo::InferTensorInfo() {
|
||||
if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid args";
|
||||
return FAILED;
|
||||
}
|
||||
// infer tensor layout
|
||||
TensorLayout input_layout, output_layout;
|
||||
if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
|
||||
return FAILED;
|
||||
}
|
||||
if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
TensorInfo input_tensor_info(input_layout);
|
||||
TensorInfo output_tensor_info(output_layout);
|
||||
|
||||
inputs_tensor_info_.push_back(input_tensor_info);
|
||||
outputs_tensor_info_.push_back(output_tensor_info);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
void TileInfo::UpdateMultiples(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->size() != 3) {
|
||||
MS_LOG(EXCEPTION) << "The size of tile cnode's inputs must be 3";
|
||||
}
|
||||
|
||||
if (!IsValueNode<ValueTuple>(cnode->input(2))) {
|
||||
MS_LOG(EXCEPTION) << "The input[2] of tile cnode is not ValueTuple.";
|
||||
}
|
||||
|
||||
auto func_graph = cnode->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
|
||||
ValuePtr new_multiples = MakeValue(slice_multiples_);
|
||||
AnfNodePtr val = NewValueNode(new_multiples);
|
||||
(void)manager->Replace(cnode->input(2), val);
|
||||
}
|
||||
|
||||
std::shared_ptr<std::vector<std::vector<int32_t>>> TileInfo::GenerateBatchStrategies() {
|
||||
if (InferAttrs() != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
|
||||
}
|
||||
Shapes multiples_shape = {full_multiples_};
|
||||
split_flag_list_ = {true};
|
||||
return GenerateBatchStrategiesBySplitFlag(multiples_shape, split_flag_list_);
|
||||
}
|
||||
|
||||
Status TileInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
||||
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status TileInfo::GenerateStrategies(int32_t stage_id) {
|
||||
if (InferAttrs() != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Infer attrs failed";
|
||||
return FAILED;
|
||||
}
|
||||
Shape multiples_split(full_multiples_.size(), 1);
|
||||
Shapes splittable_inputs = {multiples_split};
|
||||
|
||||
std::vector<StrategyPtr> sp_vector;
|
||||
is_auto_parallel_ = true;
|
||||
Shapes tmp_inputs_shape = {full_multiples_};
|
||||
if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_inputs, &sp_vector) != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
size_t success = 0;
|
||||
for (auto &sp : sp_vector) {
|
||||
PrintStrategy(sp);
|
||||
if (SetCostUnderStrategy(sp) == SUCCESS) {
|
||||
success++;
|
||||
MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy.";
|
||||
PrintStrategy(sp);
|
||||
}
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status TileInfo::Init(const StrategyPtr &strategy) {
|
||||
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Init failed.";
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG(INFO) << name_ << ": Init success.";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status TileInfo::InitForCostModel(const StrategyPtr &strategy) {
|
||||
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << name_ << ": Init for cost model success.";
|
||||
return SUCCESS;
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,65 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_TILE_INFO_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_TILE_INFO_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "ir/value.h"
|
||||
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
|
||||
#include "frontend/parallel/ops_info/operator_info.h"
|
||||
#include "frontend/parallel/strategy.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
class TileInfo : public OperatorInfo {
|
||||
public:
|
||||
TileInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<TileCost>(false)) {}
|
||||
~TileInfo() override = default;
|
||||
|
||||
Status Init(const StrategyPtr &strategy) override;
|
||||
Status InitForCostModel(const StrategyPtr &strategy) override;
|
||||
Status GenerateStrategies(int32_t) override;
|
||||
Status SetCostUnderStrategy(const StrategyPtr &) override;
|
||||
std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override;
|
||||
void UpdateMultiples(const CNodePtr &cnode);
|
||||
|
||||
protected:
|
||||
Status GetAttrs() override;
|
||||
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||
Status InferMirrorOps() override;
|
||||
Status InferForwardCommunication() override { return SUCCESS; }
|
||||
Status InferTensorInfo() override;
|
||||
Status InferDevMatrixShape() override;
|
||||
Status InferTensorMap() override;
|
||||
|
||||
private:
|
||||
std::vector<int32_t> full_multiples_;
|
||||
std::vector<int32_t> slice_multiples_;
|
||||
};
|
||||
|
||||
using TileInfoPtr = std::shared_ptr<TileInfo>;
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_TILE_INFO_H_
|
|
@ -260,7 +260,7 @@ bool IsSplittableOperator(const std::string &op_name) {
|
|||
REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
|
||||
MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP,
|
||||
LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT,
|
||||
STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2,
|
||||
STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE,
|
||||
SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS};
|
||||
// clang-format on
|
||||
|
||||
|
|
|
@ -1890,8 +1890,24 @@ void HandleDropoutNode(const OperatorInfoPtr &distribute_operator, const CNodePt
|
|||
ReplaceOneOp(replace_op[0], cnode->input(DROPOUT_GEN_MASK_INDEX)->cast<CNodePtr>());
|
||||
}
|
||||
|
||||
void HandleTileNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->size() < 3 || !IsValueNode<Primitive>(cnode->input(0))) {
|
||||
return;
|
||||
}
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
if (prim->name() != TILE) {
|
||||
return;
|
||||
}
|
||||
|
||||
TileInfoPtr tile = std::dynamic_pointer_cast<TileInfo>(distribute_operator);
|
||||
MS_EXCEPTION_IF_NULL(tile);
|
||||
tile->UpdateMultiples(cnode);
|
||||
}
|
||||
|
||||
void HandleSpecialNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
|
||||
HandleDropoutNode(distribute_operator, cnode);
|
||||
HandleTileNode(distribute_operator, cnode);
|
||||
}
|
||||
|
||||
std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_nodes) {
|
||||
|
|
|
@ -0,0 +1,128 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import context, Tensor, Parameter
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.nn import Cell, TrainOneStepCell, Momentum
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class Net(Cell):
|
||||
def __init__(self, weight, weight2, strategy1=None, strategy2=None, is_parameter=True):
|
||||
super().__init__()
|
||||
self.mul = P.Mul().set_strategy(strategy1)
|
||||
self.tile = P.Tile().set_strategy(strategy2)
|
||||
if is_parameter:
|
||||
self.weight = Parameter(weight, "w1")
|
||||
else:
|
||||
self.weight = weight
|
||||
self.mul2 = P.Mul()
|
||||
self.weight2 = Parameter(weight2, "w2")
|
||||
|
||||
def construct(self, x, b):
|
||||
out = self.tile(self.weight, (8, 4, 2))
|
||||
out = self.mul(x, out)
|
||||
out = self.mul2(out, self.weight2)
|
||||
return out
|
||||
|
||||
|
||||
class Net2(Cell):
|
||||
def __init__(self, weight2, strategy1=None, strategy2=None):
|
||||
super().__init__()
|
||||
self.mul = P.Mul().set_strategy(strategy1)
|
||||
self.tile = P.Tile().set_strategy(strategy2)
|
||||
self.weight2 = Parameter(weight2, "w2")
|
||||
|
||||
def construct(self, x, b):
|
||||
out = self.mul(x, self.weight2)
|
||||
out = self.tile(out, (8, 8, 4, 2))
|
||||
return out
|
||||
|
||||
|
||||
_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
|
||||
_w1 = Tensor(np.ones([16, 16, 16]), dtype=ms.float32)
|
||||
_w2 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
|
||||
_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
|
||||
|
||||
|
||||
def compile_net(net):
|
||||
context.set_context(save_graphs=True)
|
||||
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
train_net = TrainOneStepCell(net, optimizer)
|
||||
train_net.set_auto_parallel()
|
||||
_executor.compile(train_net, _x, _b)
|
||||
context.reset_auto_parallel_context()
|
||||
|
||||
|
||||
def test_tile_parameter():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2, 2), (2, 2, 2))
|
||||
strategy2 = ((2, 2, 2),)
|
||||
net = Net(_w1, _w2, strategy1, strategy2, is_parameter=True)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_tile_parameter_no_full_split():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2, 2), (2, 2, 2))
|
||||
strategy2 = ((2, 2, 1),)
|
||||
net = Net(_w1, _w2, strategy1, strategy2, is_parameter=True)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_tile_tensor():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2, 2), (2, 2, 2))
|
||||
strategy2 = ((2, 2, 2),)
|
||||
net = Net(_w1, _w2, strategy1, strategy2, is_parameter=False)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_tile_tensor_no_full_split():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2, 2), (2, 2, 2))
|
||||
strategy2 = ((2, 2, 1),)
|
||||
net = Net(_w1, _w2, strategy1, strategy2, is_parameter=False)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_tile_output():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2, 2), (2, 2, 2))
|
||||
strategy2 = ((1, 2, 2, 2),)
|
||||
net = Net2(_w2, strategy1, strategy2)
|
||||
compile_net(net)
|
||||
|
||||
def test_tile_output_no_full_split():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2, 2), (2, 2, 2))
|
||||
strategy2 = ((1, 2, 1, 2),)
|
||||
net = Net2(_w2, strategy1, strategy2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_tile_no_strategy():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2, 2), (2, 2, 2))
|
||||
strategy2 = None
|
||||
net = Net2(_w2, strategy1, strategy2)
|
||||
compile_net(net)
|
||||
|
||||
def test_tile_auto_parallel():
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||
net = Net2(_w2)
|
||||
compile_net(net)
|
Loading…
Reference in New Issue