forked from mindspore-Ecosystem/mindspore
add the new operatorInfo for parallel: CumSum
parent 4cc59dee17
commit 8042c88223
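The patch registers a CumSumInfo operator info so the parallel planner can shard CumSum, with the restriction that the cumulative axis itself must stay unsplit. A minimal usage sketch, assuming a MindSpore build with this patch applied and mirroring the new UT at the end of the diff (CumSumNet and the shapes below are illustrative, not part of the change):

# Sketch only: shard CumSum in semi_auto_parallel on any dimension except the
# cumulative axis (axis 0 below). Compiling such a net succeeds, while a strategy
# that splits axis 0 is rejected by CumSumInfo::CheckStrategy.
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import operations as P


class CumSumNet(nn.Cell):
    def __init__(self):
        super().__init__()
        # ((1, 16),): keep axis 0 whole, split axis 1 across 16 devices.
        self.cumsum = P.CumSum().shard(((1, 16),))

    def construct(self, x):
        return self.cumsum(x, 0)


context.set_auto_parallel_context(device_num=16, global_rank=0,
                                  parallel_mode="semi_auto_parallel")
net = CumSumNet()
x = Tensor(np.ones([128, 64]), dtype=ms.float32)
# Graph compilation of net(x) is what the new tests exercise via compile_net().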
@@ -209,6 +209,7 @@ REGISTER(DSDMatmulInfo);
 REGISTER(UniformRealInfo);
 REGISTER(ResizeBilinearInfo);
 REGISTER(ResizeNearestNeighborInfo);
+REGISTER(CumSumInfo);
 }  // namespace parallel
 }  // namespace mindspore
@@ -204,6 +204,79 @@ std::vector<StrategyPtr> Softmax::GenerateOpStrategies(int64_t stage_id) {
   return sp_vector;
 }
 
+Status CumSumInfo::GetAttrs() {
+  if (input_value_.size() != CUMSUM_INPUT_SIZE) {
+    MS_LOG(ERROR) << name_ << ": Invalid inputs size " << input_value_.size();
+    return FAILED;
+  }
+
+  if (!input_value_.back()->isa<Int64Imm>()) {
+    MS_LOG(ERROR) << name_ << ": The type of axis is not int64_t";
+    return FAILED;
+  }
+
+  int64_t axis = GetValue<int64_t>(input_value_.back());
+
+  if (inputs_shape_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
+    return FAILED;
+  }
+
+  // Validate the axis and normalize a negative value into [0, dim).
+  int64_t dim = SizeToLong(inputs_shape_[0].size());
+  if ((axis > dim - 1) || (axis < -dim)) {
+    MS_LOG(ERROR) << name_ << ": The axis(" << axis << ") is out of range [" << -dim << ", " << dim << ")";
+    return FAILED;
+  }
+
+  if (axis < 0) {
+    axis_ = dim + axis;
+  } else {
+    axis_ = axis;
+  }
+  MS_LOG(INFO) << name_ << ": The axis is " << axis;
+  return SUCCESS;
+}
+
+Status CumSumInfo::CheckStrategy(const StrategyPtr &strategy) {
+  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Invalid strategy.";
+    return FAILED;
+  }
+
+  Strategys stra = strategy->GetInputDim();
+  Dimensions input_strategy = stra.at(0);
+  if (input_strategy.size() <= IntToSize(axis_)) {
+    MS_LOG(ERROR) << "The " << name_ << " input strategy length: " << input_strategy.size()
+                  << ", is less than or equal to " << axis_;
+    return FAILED;
+  }
+  // The cumulative axis itself must not be split across devices.
+  auto axis_split = input_strategy[axis_];
+  if (axis_split > 1) {
+    MS_LOG(ERROR) << "Currently, CumSum does not support sharding strategies which split the axis.";
+    return FAILED;
+  }
+
+  return SUCCESS;
+}
+
+std::vector<StrategyPtr> CumSumInfo::GenerateOpStrategies(int64_t stage_id) {
+  Shape input0_split(inputs_shape_[0].size(), 1);
+  if (axis_ < 0 || IntToSize(axis_) >= inputs_shape_[0].size()) {
+    MS_LOG(EXCEPTION) << "Wrong axis value: " << axis_;
+  }
+  // Currently, CumSum does not support sharding strategies which split the axis.
+  input0_split[axis_] = 0;
+  Shapes splittable_inputs = {input0_split};
+
+  std::vector<StrategyPtr> sp_vector;
+  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
+    MS_LOG(EXCEPTION) << name_ << " : Generate strategies for independent inputs failed.";
+  }
+  return sp_vector;
+}
+
+Status CumSumInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
+
 Status ActivationBase::InferDevMatrixShape() {
   Strategys stra = strategy_->GetInputDim();
   Dimensions input_strategy = stra.at(0);
@@ -119,6 +119,23 @@ class Softmax : public ActivationBase {
   std::vector<int64_t> axis_;
 };
 
+class CumSumInfo : public ActivationBase {
+ public:
+  explicit CumSumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+                      const PrimitiveAttrs &attrs)
+      : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCost>()) {}
+  ~CumSumInfo() override = default;
+  std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
+  Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
+
+ protected:
+  Status CheckStrategy(const StrategyPtr &strategy) override;
+  Status GetAttrs() override;
+
+ private:
+  int64_t axis_ = -1;
+};
+
 class SoftmaxInfo : public Softmax {
  public:
   SoftmaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
@@ -53,6 +53,7 @@ constexpr size_t SOFTMAX_ATTR_SIZE = 1;
 constexpr size_t ACTIVATION_INPUTS_SIZE = 1;
 constexpr size_t ACTIVATION_OUTPUTS_SIZE = 1;
 constexpr size_t EXPANDDIMS_INPUT_SIZE = 2;
+constexpr size_t CUMSUM_INPUT_SIZE = 2;
 constexpr size_t DROPOUT_DO_MASK_CNODE_INPUT_SIZE = 4;
 constexpr size_t DROPOUT_GEN_MASK_CNODE_INPUT_SIZE = 3;
 constexpr size_t DROPOUT_GEN_MASK_INDEX = 2;
@@ -414,6 +415,7 @@ constexpr char GATHERD[] = "GatherD";
 constexpr char DSD_MATMUL[] = "DSDMatmul";
 constexpr char RESIZE_BILINEAR[] = "ResizeBilinear";
 constexpr char RESIZE_NEAREST_NEIGHBOR[] = "ResizeNearestNeighbor";
+constexpr char CUMSUM[] = "CumSum";
 
 // pipeline
 constexpr size_t PIPELINE_FUSTION_OFFSET = 100;
@@ -168,7 +168,7 @@ bool IsSplittableOperator(const std::string &op_name) {
     SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
     UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, SELECT, GATHERD,
     UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, VIRTUAL_OUTPUT, CONV2D_BACK_PROP_INPUT, CONV2D_TRANSPOSE,
-    MATMUL_DDS, DSD_MATMUL, UNIFORMREAL, RESIZE_BILINEAR, RESIZE_NEAREST_NEIGHBOR};
+    MATMUL_DDS, DSD_MATMUL, UNIFORMREAL, RESIZE_BILINEAR, RESIZE_NEAREST_NEIGHBOR, CUMSUM};
   // clang-format on
 
   auto iter = splittable_op.find(op_name);
@@ -0,0 +1,131 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import pytest
+
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.common.api import _cell_graph_executor
+from mindspore.ops import composite as C
+from mindspore.ops import operations as P
+from tests.ut.python.ops.test_math_ops import VirtualLoss
+
+grad_all = C.GradOperation(get_all=True)
+
+
+class NetWithLoss(nn.Cell):
+    def __init__(self, network):
+        super(NetWithLoss, self).__init__()
+        self.loss = VirtualLoss()
+        self.network = network
+
+    def construct(self, x, y):
+        predict = self.network(x, y)
+        return self.loss(predict)
+
+
+class GradWrap(nn.Cell):
+    def __init__(self, network):
+        super(GradWrap, self).__init__()
+        self.network = network
+
+    def construct(self, x, y):
+        return grad_all(self.network)(x, y)
+
+
+def compile_net(net, x, y):
+    net.set_auto_parallel()
+    net.set_train()
+    _cell_graph_executor.compile(net, x, y)
+
+
+def test_cumsum_semi():
+    """
+    Feature: CumSum operatorInfo in parallel.
+    Description: MatMul->CumSum
+    Expectation: CumSum does not support splitting the cumulative axis, so compilation raises a RuntimeError.
+    """
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.matmul1 = P.MatMul().shard(((16, 1), (1, 1)))
+            self.cumsum = P.CumSum().shard(((16, 1),))
+
+        def construct(self, x, y):
+            out = self.matmul1(x, y)
+            out = self.cumsum(out, 0)
+            return out
+
+    size = 16
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
+
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    with pytest.raises(RuntimeError):
+        compile_net(net, x, y)
+
+
+def test_cumsum_semi2():
+    """
+    Feature: CumSum operatorInfo in parallel.
+    Description: MatMul->CumSum
+    Expectation: Compile done without error.
+    """
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.matmul1 = P.MatMul().shard(((16, 1), (1, 1)))
+            self.cumsum = P.CumSum().shard(((1, 16),))
+
+        def construct(self, x, y):
+            out = self.matmul1(x, y)
+            out = self.cumsum(out, 0)
+            return out
+
+    size = 16
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
+
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    compile_net(net, x, y)
+
+
+def test_cumsum_auto():
+    """
+    Feature: CumSum operatorInfo in parallel.
+    Description: MatMul->CumSum
+    Expectation: Compile done without error.
+    """
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.matmul1 = P.MatMul().shard(((16, 1), (1, 1)))
+            self.cumsum = P.CumSum()
+
+        def construct(self, x, y):
+            out = self.matmul1(x, y)
+            out = self.cumsum(out, -1)
+            return out
+
+    size = 16
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
+
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    compile_net(net, x, y)