!18546 add parallel operator for maxpool and avgpool
Merge pull request !18546 from yangzhenzhang/add-parallel-op-for-maxpool
This commit is contained in:
commit
504ffc6292
|
@ -254,6 +254,7 @@ class GeLUCost : public SqrtCost {
|
|||
using BesselI0eCost = GeLUCost;
|
||||
using BesselI1eCost = GeLUCost;
|
||||
using L2NormalizeCost = GeLUCost;
|
||||
using MaxPoolCost = GeLUCost;
|
||||
|
||||
class SoftmaxCost : public OperatorCost {
|
||||
public:
|
||||
|
|
|
@ -198,6 +198,8 @@ REGISTER(ScatterUpdateInfo);
|
|||
REGISTER(VirtualOutputInfo);
|
||||
REGISTER(Conv2DInfo);
|
||||
REGISTER(BatchNormInfo);
|
||||
REGISTER(MaxPoolInfo);
|
||||
REGISTER(AvgPoolInfo);
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -0,0 +1,198 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "frontend/parallel/ops_info/maxpool_info.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "frontend/parallel/device_matrix.h"
|
||||
#include "frontend/parallel/strategy.h"
|
||||
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
|
||||
#include "pipeline/jit/resource.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
Status MaxPoolInfo::GetAttrs() {
|
||||
// kernel_size
|
||||
kernel_size_ = GetTupleIntAttr(KERNEL_SIZE);
|
||||
if (kernel_size_.size() != 4) {
|
||||
MS_LOG(ERROR) << name_ << ": The size of kernel_size must be 4, but got " << kernel_size_.size();
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (kernel_size_[0] != 1 || kernel_size_[1] != 1) {
|
||||
MS_LOG(ERROR) << name_ << ": The first two elements of kernel_size must be 1, but got (" << kernel_size_[0] << ", "
|
||||
<< kernel_size_[1] << ")";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// pad_mode
|
||||
pad_mode_ = GetIntAttr(PAD_MODE);
|
||||
if (pad_mode_ != POOL_PAD_MODE_VALID && pad_mode_ != POOL_PAD_MODE_SAME) {
|
||||
MS_LOG(ERROR) << name_ << ": The pad_mode value is invalid: " << pad_mode_;
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// stride
|
||||
stride_ = GetTupleIntAttr(STRIDES);
|
||||
if (stride_.size() != 4) {
|
||||
MS_LOG(ERROR) << name_ << ": The size of stride must be 4, but got " << stride_.size();
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (stride_[0] != 1 || stride_[1] != 1) {
|
||||
MS_LOG(ERROR) << name_ << ": The first two elements of stride must be 1, but got (" << stride_[0] << ", "
|
||||
<< stride_[1] << ")";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// format
|
||||
format_ = GetStringAttr(FORMAT);
|
||||
if (format_ != NCHW) {
|
||||
MS_LOG(ERROR) << name_ << ": The format only support 'NCHW', but got " << format_;
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << name_ << ": The kernel size is " << kernel_size_ << ", pad mode is " << pad_mode_ << ", pad list is "
|
||||
<< ", stride is " << stride_ << ", format is " << format_;
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MaxPoolInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
|
||||
if (h_strategy > 1) {
|
||||
if (kernel_size_[2] > stride_[2]) {
|
||||
MS_LOG(ERROR) << name_ << ": It does not support to split H dimension when kernel_size > stride";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
int64_t h_slice_shape = inputs_shape_[0][2] / h_strategy;
|
||||
if (h_slice_shape % stride_[2] != 0) {
|
||||
MS_LOG(ERROR) << name_
|
||||
<< ": It does not support to split H dimension when kernel_size <= stride but slice shape is not "
|
||||
"divisible by stride ";
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
if (w_strategy > 1) {
|
||||
if (kernel_size_[3] > stride_[3]) {
|
||||
MS_LOG(ERROR) << name_ << ": It does not support to split W dimension when kernel_size > stride";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
int64_t w_slice_shape = inputs_shape_[0][3] / w_strategy;
|
||||
if (w_slice_shape % stride_[3] != 0) {
|
||||
MS_LOG(ERROR) << name_
|
||||
<< ": It does not support to split W dimension when kernel_size <= stride but slice shape is not "
|
||||
"divisible by stride ";
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MaxPoolInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
MS_EXCEPTION_IF_NULL(strategy);
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
std::vector<Dimensions> stra = strategy->GetInputDim();
|
||||
if (stra.size() != 1) {
|
||||
MS_LOG(ERROR) << name_ << ": The size of strategy must be 1, but got " << stra.size();
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
Dimensions input_strategy = stra[0];
|
||||
if (input_strategy.size() != 4) {
|
||||
MS_LOG(ERROR) << name_ << ": The size of input strategy must be 4, but got" << input_strategy.size();
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (input_strategy[2] != 1 || input_strategy[3] != 1) {
|
||||
if (CheckHWStrategy(input_strategy[2], input_strategy[3]) != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MaxPoolInfo::InferDevMatrixShape() {
|
||||
// the strategy is (n, c, h, w)
|
||||
// the dev matrix is (n, c, h, w)
|
||||
MS_EXCEPTION_IF_NULL(strategy_);
|
||||
std::vector<Dimensions> stra = strategy_->GetInputDim();
|
||||
if (stra.empty()) {
|
||||
MS_LOG(ERROR) << name_ << ": The strategy is empty";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
dev_matrix_shape_ = stra[0];
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MaxPoolInfo::InferTensorMap() {
|
||||
// input_strategy: (n, c, h, w)
|
||||
// output_strategy: (n, c, h, w)
|
||||
// dev_matrix: (n, c, h, w)
|
||||
TensorMap input_tensor_map = {3, 2, 1, 0};
|
||||
TensorMap output_tensor_map = {3, 2, 1, 0};
|
||||
|
||||
(void)inputs_tensor_map_.emplace_back(std::move(input_tensor_map));
|
||||
(void)outputs_tensor_map_.emplace_back(std::move(output_tensor_map));
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MaxPoolInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
||||
|
||||
std::vector<StrategyPtr> MaxPoolInfo::GenerateOpStrategies(int64_t stage_id) {
|
||||
Shape input0_split(inputs_shape_[0].size(), 1);
|
||||
Shapes splittable_inputs = {input0_split};
|
||||
|
||||
std::vector<StrategyPtr> sp_vector;
|
||||
if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << name_ << " : Generate strategies for independent inputs() failed.";
|
||||
}
|
||||
return sp_vector;
|
||||
}
|
||||
|
||||
Status MaxPoolInfo::Init(const StrategyPtr &strategy) {
|
||||
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Init failed.";
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG(INFO) << name_ << ": Init success.";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MaxPoolInfo::InitForCostModel(const StrategyPtr &strategy) {
|
||||
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << name_ << ": Init for cost model success.";
|
||||
return SUCCESS;
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,72 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_MAXPOOL_INFO_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_MAXPOOL_INFO_H_
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "ir/value.h"
|
||||
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
|
||||
#include "frontend/parallel/ops_info/operator_info.h"
|
||||
#include "frontend/parallel/strategy.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
class MaxPoolInfo : public OperatorInfo {
|
||||
public:
|
||||
MaxPoolInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<MaxPoolCost>()) {}
|
||||
~MaxPoolInfo() override = default;
|
||||
|
||||
Status Init(const StrategyPtr &strategy) override;
|
||||
Status InitForCostModel(const StrategyPtr &strategy) override;
|
||||
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
|
||||
Status SetCostUnderStrategy(const StrategyPtr &) override;
|
||||
|
||||
protected:
|
||||
Status GetAttrs() override;
|
||||
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||
Status InferForwardCommunication() override { return SUCCESS; }
|
||||
Status InferDevMatrixShape() override;
|
||||
Status InferTensorMap() override;
|
||||
Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
|
||||
|
||||
private:
|
||||
std::vector<int64_t> kernel_size_; // four integers
|
||||
int64_t pad_mode_ = 0; // "same": 1; "valid": 2;
|
||||
std::vector<int64_t> stride_; // four integers
|
||||
std::string format_;
|
||||
};
|
||||
|
||||
class AvgPoolInfo : public MaxPoolInfo {
|
||||
public:
|
||||
AvgPoolInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: MaxPoolInfo(name, inputs_shape, outputs_shape, attrs) {}
|
||||
~AvgPoolInfo() override = default;
|
||||
};
|
||||
|
||||
constexpr int64_t POOL_PAD_MODE_SAME = 1;
|
||||
constexpr int64_t POOL_PAD_MODE_VALID = 2;
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_MAXPOOL_INFO_H_
|
|
@ -57,5 +57,6 @@
|
|||
#include "frontend/parallel/ops_info/virtual_output_info.h"
|
||||
#include "frontend/parallel/ops_info/conv2d_info.h"
|
||||
#include "frontend/parallel/ops_info/batchnorm_info.h"
|
||||
#include "frontend/parallel/ops_info/maxpool_info.h"
|
||||
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_
|
||||
|
|
|
@ -236,6 +236,7 @@ constexpr char ACTIVATION[] = "Activation";
|
|||
constexpr char PRELU[] = "PReLU";
|
||||
constexpr char FLOORDIV[] = "FloorDiv";
|
||||
constexpr char MAXPOOL[] = "MaxPool";
|
||||
constexpr char AVGPOOL[] = "AvgPool";
|
||||
constexpr char MAXPOOLV2[] = "MaxPoolV2";
|
||||
constexpr char L2_NORMALIZE[] = "L2Normalize";
|
||||
constexpr char TRANSPOSE[] = "Transpose";
|
||||
|
|
|
@ -158,7 +158,7 @@ bool IsSplittableOperator(const std::string &op_name) {
|
|||
// clang-format off
|
||||
static const std::set<std::string> splittable_op =
|
||||
{MATMUL, TRANSPOSE, GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU,
|
||||
FLOORDIV, L2_NORMALIZE, ADD, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK,
|
||||
FLOORDIV, L2_NORMALIZE, ADD, MAXPOOL, AVGPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK,
|
||||
REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
|
||||
MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, STACK,
|
||||
LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT,
|
||||
|
|
|
@ -0,0 +1,131 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import context, Tensor, Parameter
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.nn import Cell, TrainOneStepCell, Momentum
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class Net(Cell):
|
||||
def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, pool_kernel_size, pool_strides,
|
||||
strategy1=None, strategy2=None):
|
||||
super().__init__()
|
||||
self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
|
||||
pad_mode=pad_mode, stride=stride).shard(strategy1)
|
||||
self.conv2d_weight = Parameter(conv2d_weight, "w1")
|
||||
self.max_pool = P.MaxPool(kernel_size=pool_kernel_size, strides=pool_strides).shard(strategy2)
|
||||
|
||||
def construct(self, x, b):
|
||||
out = self.conv2d(x, self.conv2d_weight)
|
||||
out = self.max_pool(out)
|
||||
return out
|
||||
|
||||
|
||||
class Net2(Cell):
|
||||
def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, pool_kernel_size, pool_strides,
|
||||
strategy1=None, strategy2=None):
|
||||
super().__init__()
|
||||
self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
|
||||
pad_mode=pad_mode, stride=stride).shard(strategy1)
|
||||
self.conv2d_weight = Parameter(conv2d_weight, "w1")
|
||||
self.avg_pool = P.AvgPool(kernel_size=pool_kernel_size, strides=pool_strides).shard(strategy2)
|
||||
|
||||
def construct(self, x, b):
|
||||
out = self.conv2d(x, self.conv2d_weight)
|
||||
out = self.avg_pool(out)
|
||||
return out
|
||||
|
||||
|
||||
_x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
|
||||
_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
|
||||
_b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
|
||||
|
||||
|
||||
def compile_net(net):
|
||||
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
train_net = TrainOneStepCell(net, optimizer)
|
||||
train_net.set_auto_parallel()
|
||||
train_net.set_train()
|
||||
_executor.compile(train_net, _x, _b)
|
||||
context.reset_auto_parallel_context()
|
||||
|
||||
|
||||
def test_maxpool_data_parallel():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
|
||||
strategy2 = ((8, 1, 1, 1),)
|
||||
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=2,
|
||||
strategy1=strategy1, strategy2=strategy2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_maxpool_model_parallel1():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
|
||||
strategy2 = ((2, 1, 2, 2),)
|
||||
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=2,
|
||||
strategy1=strategy1, strategy2=strategy2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_maxpool_model_parallel2():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
|
||||
strategy2 = ((2, 1, 2, 2),)
|
||||
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=4,
|
||||
strategy1=strategy1, strategy2=strategy2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_maxpool_auto_parallel():
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=4)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_avgpool_data_parallel():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
|
||||
strategy2 = ((8, 1, 1, 1),)
|
||||
net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=2,
|
||||
strategy1=strategy1, strategy2=strategy2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_avgpool_model_parallel1():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
|
||||
strategy2 = ((2, 1, 2, 2),)
|
||||
net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=2,
|
||||
strategy1=strategy1, strategy2=strategy2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_avgpool_model_parallel2():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
|
||||
strategy2 = ((2, 1, 2, 2),)
|
||||
net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=4,
|
||||
strategy1=strategy1, strategy2=strategy2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_avgpool_auto_parallel():
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||
net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=4)
|
||||
compile_net(net)
|
Loading…
Reference in New Issue