forked from mindspore-Ecosystem/mindspore
Add parallel op for ResizeNearestNeighbor
This commit is contained in:
parent
c3ce9c56c6
commit
c42081619e
|
@ -35,14 +35,7 @@ Status ResizeBilinearInfo::GetAttrs() {
|
|||
return FAILED;
|
||||
}
|
||||
|
||||
if (size_[0] != size_[1]) {
|
||||
MS_LOG(ERROR) << name_ << ": The second two elements of size must be the same, but got (" << size_[0] << ", "
|
||||
<< size_[1] << ")";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
align_corners_ = GetBoolAttr(ALIGN_CORNERS);
|
||||
|
||||
MS_LOG(INFO) << name_ << ": The input size is " << size_ << ", align_corners is " << align_corners_;
|
||||
|
||||
return SUCCESS;
|
||||
|
@ -85,7 +78,15 @@ Status ResizeBilinearInfo::InferDevMatrixShape() {
|
|||
return FAILED;
|
||||
}
|
||||
|
||||
if (stra[0].size() != 4) {
|
||||
MS_LOG(ERROR) << name_ << ": The size of strategy must be 4, but got " << stra[0].size();
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
dev_matrix_shape_ = stra[0];
|
||||
slice_size_ = size_;
|
||||
slice_size_[0] = slice_size_[0] / dev_matrix_shape_[2];
|
||||
slice_size_[1] = slice_size_[1] / dev_matrix_shape_[3];
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
|
@ -135,5 +136,40 @@ Status ResizeBilinearInfo::InitForCostModel(const StrategyPtr &strategy) {
|
|||
MS_LOG(INFO) << name_ << ": Init for cost model success.";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
void ResizeBilinearInfo::ReplaceNodeInputOrAttrs() {
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode_->input(0));
|
||||
prim->set_attr(SIZE, MakeValue(slice_size_));
|
||||
}
|
||||
|
||||
// Validates `strategy` for ResizeNearestNeighbor.
// The strategy is first passed through MS_EXCEPTION_IF_NULL, then checked with
// CheckStrategyValue against inputs_shape_ and, only if that passes, against
// outputs_shape_. Returns FAILED (after logging an error) on the first check
// that does not report SUCCESS; returns SUCCESS when both checks pass.
Status ResizeNearestNeighborInfo::CheckStrategy(const StrategyPtr &strategy) {
  MS_EXCEPTION_IF_NULL(strategy);

  // The strategy must be acceptable for the input shapes ...
  const bool input_strategy_ok = (CheckStrategyValue(strategy, inputs_shape_) == SUCCESS);
  if (!input_strategy_ok) {
    MS_LOG(ERROR) << name_ << ": Check input strategy failed";
    return FAILED;
  }

  // ... and for the output shapes as well.
  const bool output_strategy_ok = (CheckStrategyValue(strategy, outputs_shape_) == SUCCESS);
  if (!output_strategy_ok) {
    MS_LOG(ERROR) << name_ << ": Check output strategy failed";
    return FAILED;
  }

  return SUCCESS;
}
|
||||
|
||||
std::vector<StrategyPtr> ResizeNearestNeighborInfo::GenerateOpStrategies(int64_t stage_id) {
|
||||
Shape multiples_split(inputs_shape_[0].size(), 1);
|
||||
Shapes splittable_inputs = {multiples_split};
|
||||
|
||||
std::vector<StrategyPtr> sp_vector;
|
||||
if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": generate strategies failed";
|
||||
}
|
||||
|
||||
return sp_vector;
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -47,11 +47,11 @@ class ResizeBilinearInfo : public OperatorInfo {
|
|||
Status InferForwardCommunication() override { return SUCCESS; }
|
||||
Status InferDevMatrixShape() override;
|
||||
Status InferTensorMap() override;
|
||||
Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
|
||||
void ReplaceNodeInputOrAttrs() override;
|
||||
|
||||
private:
|
||||
std::vector<int64_t> size_; // four integers, NCHW
|
||||
bool align_corners_;
|
||||
std::vector<int64_t> size_;
|
||||
std::vector<int64_t> slice_size_;
|
||||
bool align_corners_ = false;
|
||||
};
|
||||
|
||||
class ResizeNearestNeighborInfo : public ResizeBilinearInfo {
|
||||
|
@ -60,6 +60,10 @@ class ResizeNearestNeighborInfo : public ResizeBilinearInfo {
|
|||
const PrimitiveAttrs &attrs)
|
||||
: ResizeBilinearInfo(name, inputs_shape, outputs_shape, attrs) {}
|
||||
~ResizeNearestNeighborInfo() override = default;
|
||||
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
|
||||
|
||||
protected:
|
||||
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||
};
|
||||
|
||||
} // namespace parallel
|
||||
|
|
|
@ -3830,7 +3830,7 @@ class ResizeNearestNeighbor(PrimitiveWithInfer):
|
|||
|
||||
def infer_shape(self, x_shape):
|
||||
validator.check('the dimension of input_x', len(x_shape), '', 4, Rel.EQ, self.name)
|
||||
return tuple(x_shape)[:-2] + tuple(self.size)
|
||||
return tuple(x_shape)[:-2] + tuple(super().get_attr_dict()['size'])
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name)
|
||||
|
|
|
@ -43,17 +43,20 @@ class Net2(Cell):
|
|||
'''
|
||||
create the test Net
|
||||
'''
|
||||
def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
|
||||
def __init__(self, conv2d_weight, mul_weight, out_channel, kernel_size, pad_mode, stride,
|
||||
strategy1=None, strategy2=None):
|
||||
super(Net2, self).__init__()
|
||||
self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
|
||||
pad_mode=pad_mode, stride=stride).shard(strategy1)
|
||||
self.conv2d_weight = Parameter(conv2d_weight, "w1")
|
||||
self.resize_neighbor = P.ResizeNearestNeighbor((16, 16)).shard(strategy2)
|
||||
self.mul = P.Mul()
|
||||
self.mul_weight = Parameter(mul_weight, "w2")
|
||||
|
||||
def construct(self, x):
|
||||
out = self.conv2d(x, self.conv2d_weight)
|
||||
out = self.resize_neighbor(out)
|
||||
out = self.mul(out, self.mul_weight)
|
||||
return out
|
||||
|
||||
class Net3(Cell):
|
||||
|
@ -76,6 +79,7 @@ class Net3(Cell):
|
|||
|
||||
_x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
|
||||
_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
|
||||
_w2 = Tensor(np.ones([32, 8, 16, 16]), dtype=ms.float32)
|
||||
|
||||
|
||||
def compile_net(net, inputs=_x):
|
||||
|
@ -130,21 +134,21 @@ def test_neighbor_data_parallel():
|
|||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
|
||||
strategy2 = ((8, 1, 1, 1),)
|
||||
net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
|
||||
net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
|
||||
strategy1=strategy1, strategy2=strategy2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_neighbor_model_parallel1():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
|
||||
strategy2 = ((4, 2, 1, 1),)
|
||||
net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
|
||||
strategy2 = ((2, 2, 2, 2),)
|
||||
net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
|
||||
strategy1=strategy1, strategy2=strategy2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_neighbor_auto_parallel():
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||
net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
|
||||
net = Net2(_w1, _w2, out_channel=8, kernel_size=2, pad_mode="same", stride=1)
|
||||
compile_net(net)
|
||||
|
|
Loading…
Reference in New Issue