fix-Conv3dTranspose-Resize
This commit is contained in:
parent
7f6424cf81
commit
78b903f090
|
@ -188,6 +188,8 @@ void Conv3dGpuKernelMod::SelectAlgorithm(cudnnTensorDescriptor_t input_descripto
|
||||||
}
|
}
|
||||||
|
|
||||||
void Conv3dGpuKernelMod::SetStrideAndDilation(std::vector<int64_t> stride_me, std::vector<int64_t> dilation_me) {
|
void Conv3dGpuKernelMod::SetStrideAndDilation(std::vector<int64_t> stride_me, std::vector<int64_t> dilation_me) {
|
||||||
|
stride_.clear();
|
||||||
|
dilation_.clear();
|
||||||
(void)std::transform(stride_me.begin(), stride_me.end(), std::back_inserter(stride_),
|
(void)std::transform(stride_me.begin(), stride_me.end(), std::back_inserter(stride_),
|
||||||
[](const int64_t &value) { return static_cast<int>(value); });
|
[](const int64_t &value) { return static_cast<int>(value); });
|
||||||
(void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_),
|
(void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_),
|
||||||
|
|
|
@ -324,6 +324,8 @@ class Conv3dGradFilterGpuKernelMod : public NativeGpuKernelMod {
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetStrideAndDilation(std::vector<int64_t> stride_me, std::vector<int64_t> dilation_me) {
|
void SetStrideAndDilation(std::vector<int64_t> stride_me, std::vector<int64_t> dilation_me) {
|
||||||
|
stride_.clear();
|
||||||
|
dilation_.clear();
|
||||||
(void)std::transform(stride_me.begin(), stride_me.end(), std::back_inserter(stride_),
|
(void)std::transform(stride_me.begin(), stride_me.end(), std::back_inserter(stride_),
|
||||||
[](const int64_t &value) { return static_cast<int>(value); });
|
[](const int64_t &value) { return static_cast<int>(value); });
|
||||||
(void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_),
|
(void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_),
|
||||||
|
|
|
@ -224,6 +224,8 @@ void Conv3dTransposeFwdGpuKernelMod::Set5DDesc(const ShapeVector &input_shape, c
|
||||||
|
|
||||||
void Conv3dTransposeFwdGpuKernelMod::SetStrideAndDilation(std::vector<int64_t> stride_me,
|
void Conv3dTransposeFwdGpuKernelMod::SetStrideAndDilation(std::vector<int64_t> stride_me,
|
||||||
std::vector<int64_t> dilation_me) {
|
std::vector<int64_t> dilation_me) {
|
||||||
|
stride_.clear();
|
||||||
|
dilation_.clear();
|
||||||
(void)std::transform(stride_me.begin(), stride_me.end(), std::back_inserter(stride_),
|
(void)std::transform(stride_me.begin(), stride_me.end(), std::back_inserter(stride_),
|
||||||
[](const int64_t &value) { return static_cast<int>(value); });
|
[](const int64_t &value) { return static_cast<int>(value); });
|
||||||
(void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_),
|
(void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_),
|
||||||
|
|
Loading…
Reference in New Issue