forked from mindspore-Ecosystem/mindspore
!48205 BatchNorm1D supports 3-D input
Merge pull request !48205 from haozhang/batch_norm1d
commit b81d93bbd7
@@ -3,7 +3,7 @@ mindspore.nn.BatchNorm1d
.. py:class:: mindspore.nn.BatchNorm1d(num_features, eps=1e-5, momentum=0.9, affine=True, gamma_init='ones', beta_init='zeros', moving_mean_init='zeros', moving_var_init='ones', use_batch_statistics=None, data_format='NCHW')
Applies Batch Normalization over a 2-D input (a mini-batch of 1-D inputs) to reduce internal covariate shift. Batch Normalization is widely used in convolutional networks. See the paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_ .
Applies Batch Normalization over a 2-D or 3-D input (a mini-batch of 1-D or 2-D inputs) to reduce internal covariate shift. Batch Normalization is widely used in convolutional networks. See the paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_ .
It is trained using mini-batch data and the learned parameters; the calculation formula is as follows.
@@ -14,7 +14,7 @@ mindspore.nn.BatchNorm1d
The implementation of BatchNorm differs between graph mode and PyNative mode, so it is not recommended to change the mode after the network has been initialized.
Args:
- **num_features** (int) - Number of channels, the `C` in the input Tensor shape :math:`(N, C)` .
- **num_features** (int) - Number of features, or the number of channels `C` of the input `x` .
- **eps** (float) - :math:`\epsilon` , a value added to the denominator to ensure numerical stability. Default: 1e-5.
- **momentum** (float) - Momentum used for the running mean and running variance. Default: 0.9.
- **affine** (bool) - When set to True, :math:`\gamma` and :math:`\beta` are learnable. Default: True.
@@ -26,10 +26,10 @@ mindspore.nn.BatchNorm1d
- **data_format** (str) - The data format, either 'NHWC' or 'NCHW'. Default: 'NCHW'.
Inputs:
- **x** (Tensor) - Input Tensor of shape :math:`(N, C_{in})` .
- **x** (Tensor) - Input Tensor of shape :math:`(N, C)` or :math:`(N, C, L)` , where `N` is the batch size, `C` is the number of features or channels, and `L` is the sequence length.
Outputs:
Tensor, the normalized Tensor, of shape :math:`(N, C_{out})` .
Tensor, the normalized Tensor, of shape :math:`(N, C)` or :math:`(N, C, L)` .
Raises:
- **TypeError** - `num_features` is not an int.
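A minimal usage sketch (added for illustration, not part of this patch); the batch size, channel count, and sequence length below are arbitrary assumptions:

import numpy as np
import mindspore
from mindspore import Tensor, nn

net = nn.BatchNorm1d(num_features=4)
x_2d = Tensor(np.random.randn(8, 4), mindspore.float32)      # (N, C)
x_3d = Tensor(np.random.randn(8, 4, 16), mindspore.float32)  # (N, C, L), newly supported
print(net(x_2d).shape)  # expected: (8, 4)
print(net(x_3d).shape)  # expected: (8, 4, 16)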
@@ -36,10 +36,10 @@ mindspore.nn.BatchNorm2d
- **data_format** (str) - The data format, either 'NHWC' or 'NCHW'. Default: 'NCHW'.
Inputs:
- **x** (Tensor) - Input Tensor of shape :math:`(N,C_{in},H_{in},W_{in})` .
- **x** (Tensor) - Input Tensor of shape :math:`(N, C, H, W)` .
Outputs:
Tensor, the normalized Tensor, of shape :math:`(N,C_{out},H_{out},W_{out})` .
Tensor, the normalized Tensor, of shape :math:`(N, C, H, W)` .
Raises:
- **TypeError** - `num_features` is not an int.
@@ -20,7 +20,7 @@ mindspore.ops.BatchNorm
- **is_training** (bool) - If `is_training` is True, `mean` and `variance` are computed during training. If `is_training` is False, they are loaded from a checkpoint during inference. Default: False.
- **epsilon** (float) - A value added to the denominator to ensure numerical stability. Default: 1e-5.
- **momentum** (float) - Momentum used for the running mean and running variance (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`). The momentum value must be in [0, 1]. Default: 0.1.
- **data_format** (str) - The input data format, either 'NHWC' or 'NCHW'. Default: 'NCHW'.
- **data_format** (str) - The input data format, either 'NHWC' or 'NCHW'; 'NHWC' is only supported on GPU. Default: 'NCHW'.
Inputs:
If `is_training` is False, the inputs are multiple Tensors.
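The momentum update above is a simple exponential moving average; a NumPy sketch (illustration only, not code from this patch):

import numpy as np

def update_running_stats(running_mean, running_var, cur_mean, cur_var, momentum=0.1):
    # new_running_mean = (1 - momentum) * running_mean + momentum * current_mean
    new_mean = (1.0 - momentum) * running_mean + momentum * cur_mean
    new_var = (1.0 - momentum) * running_var + momentum * cur_var
    return new_mean, new_var

mean, var = update_running_stats(np.zeros(3), np.ones(3), np.array([1.0, 2.0, 3.0]), np.full(3, 0.5))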
@@ -5,9 +5,6 @@ mindspore.ops.ResizeBicubic
Resizes images to the given size using bicubic interpolation.
.. warning::
The maximum output length is 1000000.
Args:
- **align_corners** (bool, optional) - If True, the centers of the 4 corner pixels of the input and output images are aligned, preserving the values at the corner pixels. Default: False.
- **half_pixel_centers** (bool, optional) - Whether to use half-pixel center alignment. If set to True, `align_corners` should be set to False. Default: False.
@@ -25,23 +25,23 @@ namespace kernel {
namespace {
constexpr size_t kBatchNormInputsNum = 5;
constexpr size_t kBatchNormOutputsNum = 5;
constexpr size_t kBatchNormInputShapeSize = 4;
constexpr size_t kBatchNormInputShapeSize2 = 2;
constexpr size_t kBatchNormInputShapeMaxSize = 4;
} // namespace
bool BatchNormCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
auto kernel_ptr = std::dynamic_pointer_cast<ops::BatchNorm>(base_operator);
if (!kernel_ptr) {
MS_LOG(ERROR) << "cast BatchNorm ops failed!";
return false;
}
auto kernel_name = kernel_ptr->GetPrim()->name();
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
kernel_name_ = kernel_ptr->GetPrim()->name();
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
bool is_match = MatchKernelAttr(kernel_attr, GetOpSupport()).first;
if (!is_match) {
MS_LOG(EXCEPTION) << kernel_name << " does not support this kernel data type: " << kernel_attr;
MS_LOG(EXCEPTION) << kernel_name_ << " does not support this kernel data type: " << kernel_attr;
}
base_operator_ = base_operator;
@@ -60,11 +60,7 @@ int BatchNormCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const st
}
auto x_shape = inputs[kIndex0]->GetDeviceShapeAdaptively();
if (x_shape.size() == kBatchNormInputShapeSize2) {
(void)x_shape.insert(x_shape.end(), kBatchNormInputShapeSize - kBatchNormInputShapeSize2, 1);
} else if (x_shape.size() != kBatchNormInputShapeSize) {
MS_LOG(EXCEPTION) << "Batchnorm only support nchw input!";
}
(void)x_shape.insert(x_shape.end(), kBatchNormInputShapeMaxSize - x_shape.size(), 1);
batch_size_ = x_shape[0];
channel_ = x_shape[1];
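For reference, a small Python sketch (not part of the patch) of what this Resize logic does to the input shape: a 2-D or 3-D shape is padded with trailing 1s up to kBatchNormInputShapeMaxSize (4) before the kernel runs.

def pad_to_4d(x_shape):
    # (N, C) -> (N, C, 1, 1); (N, C, L) -> (N, C, L, 1); 4-D shapes are unchanged.
    return list(x_shape) + [1] * (4 - len(x_shape))

assert pad_to_4d([8, 4]) == [8, 4, 1, 1]
assert pad_to_4d([8, 4, 16]) == [8, 4, 16, 1]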
@@ -25,27 +25,27 @@ namespace kernel {
namespace {
constexpr size_t kBatchNormGradInputsNum = 6;
constexpr size_t kBatchNormGradOutputsNum = 3;
constexpr size_t kBatchNormGradInputShapeSize = 4;
constexpr size_t kBatchNormGradInputShapeSize2 = 2;
constexpr size_t kBatchNormGradInputShapeMaxSize = 4;
constexpr size_t kBatchNormGradInputShapeMinSize = 2;
constexpr size_t kScaleShiftNum = 2;
} // namespace
bool BatchNormGradCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
base_operator_ = base_operator;
auto kernel_ptr = std::dynamic_pointer_cast<ops::BatchNormGrad>(base_operator);
if (!kernel_ptr) {
MS_LOG(ERROR) << "cast BatchNormGrad ops failed!";
return false;
}
auto kernel_name = kernel_ptr->GetPrim()->name();
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
is_train_ = kernel_ptr->get_is_training();
epsilon_ = kernel_ptr->get_epsilon();
kernel_name_ = kernel_ptr->GetPrim()->name();
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
bool is_match = MatchKernelAttr(kernel_attr, GetOpSupport()).first;
if (!is_match) {
MS_LOG(EXCEPTION) << kernel_name << " does not support this kernel data type: " << kernel_attr;
MS_LOG(EXCEPTION) << kernel_name_ << " does not support this kernel data type: " << kernel_attr;
}
return true;
}
@@ -59,11 +59,9 @@ int BatchNormGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, cons
}
auto x_shape = inputs[kIndex0]->GetDeviceShapeAdaptively();
if (x_shape.size() == NC_LEN) {
(void)x_shape.insert(x_shape.end(), (SHAPE_4D - NC_LEN), 1);
} else if (x_shape.size() != SHAPE_4D) {
MS_LOG(EXCEPTION) << "Fused batchnorm support nc or nchw input!";
}
const size_t x_shape_size = x_shape.size();
(void)x_shape.insert(x_shape.end(), kBatchNormGradInputShapeMaxSize - x_shape_size, 1);
batch_size_ = x_shape[N];
channel_ = x_shape[C];
hw_size_ = x_shape[H] * x_shape[W];
@@ -24,8 +24,8 @@
namespace mindspore {
namespace kernel {
namespace {
size_t kInputSize2 = 2;
size_t kInputSize4 = 4;
constexpr size_t kBatchNormInputShapeMaxSize = 4;
constexpr size_t kBatchNormInputShapeMinSize = 2;
float kExpAvgFactorDefault = 0.1;
} // namespace
template <typename T>
@@ -73,6 +73,7 @@ bool BatchNormGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
bool BatchNormGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
auto kernel_name_ = base_operator->name();
auto kernel_ptr = std::dynamic_pointer_cast<ops::BatchNorm>(base_operator);
if (kernel_ptr == nullptr) {
@@ -99,6 +100,7 @@ bool BatchNormGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std
<< ", but got " << kernel_name_;
return false;
}
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
@@ -136,13 +138,21 @@ int BatchNormGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const st
if (ret != 0) {
return ret;
}
auto shape = inputs[kIndex0]->GetDeviceShapeAdaptively();
if (shape.size() != kInputSize2 && shape.size() != kInputSize4) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input must be 2 or 4, but got "
<< shape.size();
auto x_shape = inputs[kIndex0]->GetDeviceShapeAdaptively();
const size_t x_shape_size = x_shape.size();
auto format = inputs[kIndex0]->GetFormat();
if (x_shape_size == kBatchNormInputShapeMinSize) {
format = Format::NCHW;
} else if (format_ == Format::NHWC) {
format = Format::NHWC;
}
if (shape.size() == kInputSize2) {
(void)x_shape.insert(x_shape.begin() + (format == Format::NHWC ? kIndex1 : x_shape_size),
kBatchNormInputShapeMaxSize - x_shape_size, 1);
if (x_shape_size == kBatchNormInputShapeMinSize) {
mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
} else if (is_train_) {
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
@@ -150,12 +160,8 @@ int BatchNormGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const st
mode_ = CUDNN_BATCHNORM_SPATIAL;
}
CheckTensorSize({shape});
auto format = inputs[kIndex0]->GetFormat();
if (format_ == Format::NHWC) {
format = Format::NHWC;
}
SetTensorDescriptor(format, shape);
CheckTensorSize({x_shape});
SetTensorDescriptor(format, x_shape);
InitSizeLists();
return KRET_OK;
}
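A hedged Python sketch (illustration only) of the shape expansion the reworked GPU Resize performs before calling SetTensorDescriptor: padding 1s are inserted after the batch dimension for NHWC, otherwise appended at the end, and 2-D inputs are always treated as NCHW.

def expand_to_4d(x_shape, data_format):
    # Mirrors x_shape.insert(begin + (NHWC ? 1 : size), 4 - size, 1) in the kernel above.
    x_shape = list(x_shape)
    pad = [1] * (4 - len(x_shape))
    pos = 1 if data_format == "NHWC" else len(x_shape)
    return x_shape[:pos] + pad + x_shape[pos:]

assert expand_to_4d([8, 4], "NCHW") == [8, 4, 1, 1]       # (N, C) -> (N, C, 1, 1)
assert expand_to_4d([8, 16, 4], "NHWC") == [8, 1, 16, 4]  # (N, L, C) -> (N, 1, L, C)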
@@ -252,13 +258,7 @@ void BatchNormGpuKernelMod::InitSizeLists() {
void BatchNormGpuKernelMod::SetTensorDescriptor(const Format &format, const ShapeVector &shape) {
cudnnTensorFormat_t cudnn_format;
int batch, channel, height, width;
if (shape.size() == kInputSize2) {
batch = LongToInt(shape[kIndex0]);
channel = LongToInt(shape[kIndex1]);
height = 1;
width = 1;
cudnn_format = CUDNN_TENSOR_NCHW;
} else if (format == Format::NHWC) {
if (format == Format::NHWC) {
batch = LongToInt(shape[kIndex0]);
height = LongToInt(shape[kIndex1]);
width = LongToInt(shape[kIndex2]);
@@ -24,11 +24,12 @@
namespace mindspore {
namespace kernel {
namespace {
size_t kInputSize2 = 2;
size_t kInputSize4 = 4;
constexpr size_t kBatchNormGradInputShapeMaxSize = 4;
constexpr size_t kBatchNormGradInputShapeMinSize = 2;
} // namespace
bool BatchNormGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
auto kernel_ptr = std::dynamic_pointer_cast<ops::BatchNormGrad>(base_operator);
if (kernel_ptr == nullptr) {
@@ -86,30 +87,35 @@ int BatchNormGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, cons
return ret;
}
auto shape = inputs[kIndex0]->GetDeviceShapeAdaptively();
if (shape.size() != kInputSize2 && shape.size() != kInputSize4) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input must be 2 or 4, but got "
<< shape.size();
beta_data_diff_ = 0;
auto x_shape = inputs[kIndex0]->GetDeviceShapeAdaptively();
const size_t x_shape_size = x_shape.size();
auto format = inputs[kIndex0]->GetFormat();
if (x_shape_size == kBatchNormGradInputShapeMinSize) {
format = Format::NCHW;
} else if (format_ == Format::NHWC) {
format = Format::NHWC;
}
is_null_input_ = CHECK_SHAPE_NULL(shape, kernel_name_, "input");
(void)x_shape.insert(x_shape.begin() + (format == Format::NHWC ? kIndex1 : x_shape_size),
kBatchNormGradInputShapeMaxSize - x_shape_size, 1);
is_null_input_ = CHECK_SHAPE_NULL(x_shape, kernel_name_, "input");
if (is_null_input_) {
InitSizeLists();
return true;
}
if (shape.size() == kInputSize2) {
if (x_shape_size == kBatchNormGradInputShapeMinSize) {
mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
} else {
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
}
beta_data_diff_ = 0;
CheckTensorSize({shape});
auto format = inputs[kIndex0]->GetFormat();
if (format_ == Format::NHWC) {
format = Format::NHWC;
}
SetTensorDescriptor(format, shape);
CheckTensorSize({x_shape});
SetTensorDescriptor(format, x_shape);
InitSizeLists();
return KRET_OK;
}
@@ -217,13 +223,7 @@ bool BatchNormGradGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inpu
void BatchNormGradGpuKernelMod::SetTensorDescriptor(const Format &format, const ShapeVector &shape) {
cudnnTensorFormat_t cudnn_format;
if (shape.size() == kInputSize2) {
batch_ = LongToInt(shape[kIndex0]);
channel_ = LongToInt(shape[kIndex1]);
height_ = 1;
width_ = 1;
cudnn_format = CUDNN_TENSOR_NCHW;
} else if (format == Format::NHWC) {
if (format == Format::NHWC) {
batch_ = LongToInt(shape[kIndex0]);
height_ = LongToInt(shape[kIndex1]);
width_ = LongToInt(shape[kIndex2]);
@@ -137,6 +137,8 @@ class BatchNormInfer : public abstract::OpInferBase {
if (!x_shape_ptr->IsDynamic() && !scale_shape_ptr->IsDynamic()) {
// auto format = GetValue<std::string>(primitive->GetAttr(kFormat));
(void)CheckAndConvertUtils::CheckInRange("rank of images", SizeToLong(x_shape.size()), kIncludeBoth, {2, 4},
prim_name);
auto format = get_format_in_infer(primitive);
auto channel = format == "NHWC" ? x_shape.back() : x_shape[1];
if (scale_shape[kInputIndex0] != channel) {
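The relaxed infer check above accepts rank 2, 3, or 4 and then validates the channel dimension against the scale shape; a Python sketch of the same rule (an illustration, not the C++ code):

def check_batch_norm_shapes(x_shape, scale_shape, data_format="NCHW"):
    if not 2 <= len(x_shape) <= 4:
        raise ValueError(f"rank of images must be in [2, 4], but got {len(x_shape)}")
    channel = x_shape[-1] if data_format == "NHWC" else x_shape[1]
    if scale_shape[0] != channel:
        raise ValueError(f"scale shape {scale_shape} does not match channel {channel}")

check_batch_norm_shapes([8, 4, 16], [4])  # 3-D NCHW input is now accepted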
@@ -100,6 +100,11 @@ class BatchNormGradInfer : public abstract::OpInferBase {
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, prim_name);
auto x_shape_ptr = input_args[kInputIndex1]->BuildShape();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape_ptr)[kShape];
if (!IsDynamicRank(x_shape)) {
(void)CheckAndConvertUtils::CheckInRange("rank of x", SizeToLong(x_shape.size()), kIncludeBoth, {2, 4},
prim_name);
}
auto scale_shape_ptr = input_args[kInputIndex2]->BuildShape();
if (prim_name == kNameBatchNormGradWithAddAndActivation) {
return std::make_shared<abstract::TupleShape>(
@@ -159,7 +159,7 @@ class _BatchNorm(Cell):
class BatchNorm1d(_BatchNorm):
r"""
This layer
applies Batch Normalization over a 2D input (a mini-batch of 1D inputs) to
applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D or 2D inputs) to
reduce internal covariate shift. Batch Normalization is widely used in convolutional networks.
For the detailed contents, refer to `Batch Normalization: Accelerating Deep Network Training by
Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It
@@ -174,7 +174,7 @@ class BatchNorm1d(_BatchNorm):
recommended to be changed after net was initialized.
Args:
num_features (int): `C` from an expected input of size (N, C).
num_features (int): number of features or channels `C` of the input `x` .
eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
momentum (float): A floating hyperparameter of the momentum for the
running_mean and running_var computation. Default: 0.9.
@@ -195,10 +195,11 @@ class BatchNorm1d(_BatchNorm):
Default: 'NCHW'.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in})`.
- **x** (Tensor) - Tensor of shape :math:`(N, C)` or :math:`(N, C, L)` ,
where `N` is the batch size, `C` is the number of features or channels, and `L` is the sequence length.
Outputs:
Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out})`.
Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C)` or :math:`(N, C, L)` .
Raises:
TypeError: If `num_features` is not an int.
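To make the computation concrete, a NumPy sketch of the standard per-channel normalization for (N, C) and (N, C, L) inputs (an illustration of the formula, not MindSpore's internal implementation):

import numpy as np

def batch_norm_1d(x, gamma, beta, eps=1e-5):
    # Reduce over every axis except the channel axis (axis 1):
    # axes (0,) for (N, C) input, axes (0, 2) for (N, C, L) input.
    axes = tuple(i for i in range(x.ndim) if i != 1)
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)
    param_shape = [1, -1] + [1] * (x.ndim - 2)
    return gamma.reshape(param_shape) * x_hat + beta.reshape(param_shape)

x = np.random.randn(8, 4, 16).astype(np.float32)
y = batch_norm_1d(x, gamma=np.ones(4, np.float32), beta=np.zeros(4, np.float32))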
@@ -276,10 +277,10 @@ class BatchNorm2d(_BatchNorm):
Default: 'NCHW'.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
- **x** (Tensor) - Tensor of shape :math:`(N, C, H, W)`.
Outputs:
Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C, H, W)`.
Raises:
TypeError: If `num_features` is not an int.
@@ -746,9 +746,6 @@ class ResizeBicubic(Primitive):
r"""
Resize images to size using bicubic interpolation.
.. warning::
The max output length is 1000000.
Args:
align_corners (bool, optional): If true, the centers of the 4 corner pixels of the input
and output tensors are aligned, preserving the values at the corner pixels. Default: False.
@@ -757,7 +754,7 @@ class ResizeBicubic(Primitive):
Inputs:
- **images** (Tensor) - The input image must be a 4-D tensor of shape :math:`(batch, channels, height, width)`.
The format must be NHWC.
The format must be NCHW.
Types allowed: int8, int16, int32, int64, float16, float32, float64, uint8, uint16.
- **size** (Tensor) - A 1-D tensor of shape [2], with 2 elements: new_height, new_width.
Types allowed: int32.
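A hedged usage sketch of the operator as documented above; the concrete image and target sizes are assumptions, and the NCHW layout follows the updated docstring:

import numpy as np
import mindspore
from mindspore import Tensor, ops

resize = ops.ResizeBicubic(align_corners=False, half_pixel_centers=False)
images = Tensor(np.random.rand(1, 3, 16, 16), mindspore.float32)  # (batch, channels, height, width)
size = Tensor(np.array([32, 32]), mindspore.int32)                # (new_height, new_width)
out = resize(images, size)  # expected shape: (1, 3, 32, 32)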
@@ -1232,8 +1232,8 @@ class BatchNorm(PrimitiveWithInfer):
momentum (float): The hyper parameter to compute moving average for running_mean and running_var
(e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
Momentum value must be [0, 1]. Default: 0.1.
data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'.
Default: "NCHW".
data_format (str): The optional value for data format, is 'NHWC' or 'NCHW', and the 'NHWC' format
is only supported in GPU target. Default: "NCHW".
Inputs:
If `is_training` is False, inputs are Tensors.