!35905 [assistant][ops][I4ZZUN] New GPU operator implementation, including ResizeArea
Merge pull request !35905 from 黎冠新/ResizeArea
This change is contained in commit 33f12ac66b.
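For context, a minimal usage sketch of the new GPU kernel from the Python front end; the call pattern follows the tests added at the end of this change, and the input values here are illustrative only:

import numpy as np
import mindspore.context as context
import mindspore.ops.operations.image_ops as ops
from mindspore import Tensor

# ResizeArea takes an NHWC image batch and an int32 [height, width] size;
# the GPU kernel always produces a float32 result regardless of the input dtype.
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
images = Tensor(np.random.rand(1, 4, 4, 3).astype(np.float32))
size = Tensor(np.array([2, 2]).astype(np.int32))
resize_area = ops.ResizeArea(align_corners=False)
output = resize_area(images, size)  # shape (1, 2, 2, 3), dtype float32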
@ -0,0 +1,112 @@
/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/arrays/resize_area_gpu_kernel.h"
#include <utility>

namespace mindspore {
namespace kernel {
namespace {
template <typename T>
std::unique_ptr<cukernel::GpuKernelHelperBase> CreateResizeAreaKernelPtr(const std::string &kernel_name,
                                                                         const uint32_t &device_id) {
  return std::make_unique<cukernel::ResizeAreaHelperGpuKernel<T>>(kernel_name, device_id);
}
using ResizeAreaPtrCreatorFunc =
  std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;

const std::vector<std::pair<KernelAttr, ResizeAreaPtrCreatorFunc>> kernel_attr = {
  {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
   CreateResizeAreaKernelPtr<int8_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
   CreateResizeAreaKernelPtr<uint8_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
   CreateResizeAreaKernelPtr<int16_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
   CreateResizeAreaKernelPtr<uint16_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
   CreateResizeAreaKernelPtr<int>},
  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
   CreateResizeAreaKernelPtr<int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
   CreateResizeAreaKernelPtr<half>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
   CreateResizeAreaKernelPtr<float>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
   CreateResizeAreaKernelPtr<double>}};
}  // namespace

bool ResizeAreaGpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                    const std::vector<AddressPtr> &outputs, void *stream_ptr) {
  std::vector<void *> input_ptrs = ConvertPtrs(inputs);
  std::vector<void *> work_ptrs = ConvertPtrs(workspace);
  std::vector<void *> output_ptrs = ConvertPtrs(outputs);
  if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) {
    return false;
  }
  return true;
}

bool ResizeAreaGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                  const std::vector<KernelTensorPtr> &outputs) {
  auto kernel_ptr = std::dynamic_pointer_cast<ops::ResizeArea>(base_operator);
  kernel_name_ = kernel_ptr->name();
  auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport());
  if (!is_match) {
    return false;
  }
  attr_ptr_->align_corners = kernel_ptr->get_align_corners();
  helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_));
  helper_ptr_->SetKernelParam(attr_ptr_);
  return true;
}

int ResizeAreaGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                   const std::vector<KernelTensorPtr> &outputs,
                                   const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
  for (const auto &input : inputs) {
    auto input_shape = input->GetShapeVector();
    if (!IsValidShape(input_shape)) {
      return KRET_UNKNOWN_SHAPE;
    }
  }
  constexpr int64_t kzero = 0;
  constexpr int64_t kone = 1;
  std::vector<std::vector<int64_t>> input_shapes;
  std::vector<std::vector<int64_t>> output_shapes;
  input_shapes.emplace_back(inputs[kzero]->GetShapeVector());
  input_shapes.emplace_back(inputs[kone]->GetShapeVector());
  output_shapes.emplace_back(outputs[kzero]->GetShapeVector());
  if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) {
    return KRET_RESIZE_FAILED;
  }
  input_size_list_ = helper_ptr_->GetInputSizeList();
  output_size_list_ = helper_ptr_->GetOutputSizeList();
  workspace_size_list_ = helper_ptr_->GetWorkSizeList();
  return KRET_OK;
}

std::vector<KernelAttr> ResizeAreaGpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  (void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list),
                       [](const std::pair<KernelAttr, ResizeAreaPtrCreatorFunc> &item) { return item.first; });
  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, ResizeArea, ResizeAreaGpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
@ -0,0 +1,57 @@
/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_IMAGE_RESIZE_AREA_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_IMAGE_RESIZE_AREA_GPU_KERNEL_H_

#include <vector>
#include <string>
#include <memory>
#include <algorithm>
#include <functional>
#include <map>
#include "mindspore/core/ops/resize_area.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/resize_area_helper.h"
#include "kernel/kernel.h"
namespace mindspore {
namespace kernel {
class ResizeAreaGpuKernelMod : public NativeGpuKernelMod {
 public:
  ResizeAreaGpuKernelMod() { attr_ptr_ = std::make_shared<cukernel::ResizeAreaAttr>(); }
  ~ResizeAreaGpuKernelMod() override = default;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(
    const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
    const std::vector<KernelTensorPtr> &outputs,
    const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  std::unique_ptr<cukernel::GpuKernelHelperBase> helper_ptr_ = nullptr;
  std::shared_ptr<cukernel::ResizeAreaAttr> attr_ptr_{nullptr};
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_IMAGE_RESIZE_AREA_GPU_KERNEL_H_
@ -0,0 +1,131 @@
/**
 * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_RESIZE_AREA_HELPER_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_RESIZE_AREA_HELPER_H_
#include <string>
#include <vector>
#include <memory>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/resize_area_impl.cuh"

namespace mindspore {
namespace cukernel {
class ResizeAreaAttr : public GpuKernelAttrBase {
 public:
  ResizeAreaAttr() = default;
  ~ResizeAreaAttr() override = default;
  bool align_corners;
};
constexpr size_t INPUT_NUM_T = 1;
constexpr size_t INPUT_NUM_SIZE = 1;
constexpr size_t OUTPUT_NUM = 1;
constexpr size_t WORK_NUM = 2;
constexpr size_t SHAPE_SIZE = 4;
constexpr int64_t kzero = 0;
constexpr int64_t kone = 1;
constexpr int64_t ktwo = 2;
constexpr int64_t kthree = 3;

template <typename T>
class ResizeAreaHelperGpuKernel : public GpuKernelHelperBase {
 public:
  explicit ResizeAreaHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id)
      : GpuKernelHelperBase(kernel_name, device_id) {
    align_corners_ = false;
    is_null_input_ = false;
  }
  virtual ~ResizeAreaHelperGpuKernel() = default;
  int CalMemSize(const std::vector<std::vector<int64_t>> &input_shapes,
                 const std::vector<std::vector<int64_t>> &output_shapes) override {
    ResetResource();
    std::vector<std::vector<int64_t>> input_shapes_T, input_shapes_size;
    input_shapes_T.emplace_back(input_shapes[kzero]);
    input_shapes_size.emplace_back(input_shapes[kone]);
    int in_flag1 =
      CalShapesSizeInBytes<T>(input_shapes_T, INPUT_NUM_T, kernel_name_, "input_shapes_images", &input_size_list_);
    if (in_flag1 != 0) {
      return in_flag1;
    }
    int in_flag2 = CalShapesSizeInBytes<int32_t>(input_shapes_size, INPUT_NUM_SIZE, kernel_name_, "input_shapes_size",
                                                 &input_size_list_);
    if (in_flag2 != 0) {
      return in_flag2;
    }
    int out_flag =
      CalShapesSizeInBytes<float>(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_);
    if (out_flag != 0) {
      return out_flag;
    }
    is_null_input_ = (in_flag1 == 1 || in_flag2 == 1 || out_flag == 1);

    batch_size_ = input_shapes_T[kzero][kzero];
    in_height_ = input_shapes_T[kzero][kone];
    in_width_ = input_shapes_T[kzero][ktwo];
    channels_ = input_shapes_T[kzero][kthree];
    out_height_ = output_shapes[kzero][kone];
    out_width_ = output_shapes[kzero][ktwo];
    size_t workspace_x_size = out_width_ * sizeof(ResizeAreaCachedInterpolation);
    size_t workspace_y_size = out_height_ * sizeof(ResizeAreaCachedInterpolation);
    work_size_list_.emplace_back(workspace_x_size);
    work_size_list_.emplace_back(workspace_y_size);
    return 0;
  }

  int Process(const std::vector<void *> &input_ptrs, const std::vector<void *> &output_ptrs,
              const std::vector<void *> &work_ptrs, void *cuda_stream) override {
    if (is_null_input_) {
      return 0;
    }
    T *image_ptr = nullptr;
    float *output_ptr = nullptr;
    int flag = GetDeviceAddress<T>(input_ptrs, 0, kernel_name_, &image_ptr);
    if (flag != 0) {
      return flag;
    }
    ResizeAreaCachedInterpolation *x_interps = nullptr;
    ResizeAreaCachedInterpolation *y_interps = nullptr;
    x_interps = reinterpret_cast<ResizeAreaCachedInterpolation *>(work_ptrs[0]);
    y_interps = reinterpret_cast<ResizeAreaCachedInterpolation *>(work_ptrs[1]);
    flag = GetDeviceAddress<float>(output_ptrs, 0, kernel_name_, &output_ptr);
    if (flag != 0) {
      return flag;
    }
    align_corners_ = attr_ptr_->align_corners;
    CalResizeArea(image_ptr, x_interps, y_interps, output_ptr, batch_size_, channels_, out_height_, out_width_,
                  in_height_, in_width_, align_corners_, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream));
    return 0;
  }

  void SetKernelParam(const GpuKernelAttrBasePtr &kernel_attr) override {
    attr_ptr_ = std::dynamic_pointer_cast<ResizeAreaAttr>(kernel_attr);
  }

 private:
  std::shared_ptr<ResizeAreaAttr> attr_ptr_;
  int post_output_size_;
  int32_t batch_size_;
  int32_t in_height_;
  int32_t in_width_;
  int32_t channels_;
  int32_t out_height_;
  int32_t out_width_;
  bool align_corners_;
  bool is_null_input_;
};
}  // namespace cukernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_RESIZE_AREA_HELPER_H_
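As a quick illustration of the workspace sizing in CalMemSize above, one cached-interpolation entry is allocated per output column (x workspace) and per output row (y workspace). The struct size below is an assumption, since the exact sizeof depends on compiler padding:

# ResizeAreaCachedInterpolation holds two size_t, two float and a bool,
# which typically pads to 32 bytes on a 64-bit target (assumed, not measured).
approx_struct_bytes = 32
out_height, out_width = 3, 3
workspace_x_bytes = out_width * approx_struct_bytes   # cached x-spans: 96 bytes
workspace_y_bytes = out_height * approx_struct_bytes  # cached y-spans: 96 bytes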
@ -0,0 +1,275 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <algorithm>
#include "resize_area_impl.cuh"
#include "include/cuda_fp16.h"

__device__ int32_t BoundResizeArea(int32_t val, int32_t limit) { return min(limit - 1, max(int32_t{0}, val)); }

__global__ void ComputeInterpolation(ResizeAreaCachedInterpolation *x_interps,
                                     ResizeAreaCachedInterpolation *y_interps, const int32_t in_height,
                                     const int32_t in_width, const int32_t out_height, const int32_t out_width,
                                     float height_scale, float width_scale, bool align_corners) {
  // The first out_height positions fill y_interps, the remaining out_width positions fill x_interps.
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
       pos < out_height + out_width; pos += gridDim.x * blockDim.x) {
    const int32_t pair_in_length[2] = {in_height, in_width};
    ResizeAreaCachedInterpolation *pair_interps[2] = {y_interps, x_interps};
    float pair_scale[2] = {height_scale, width_scale};
    int32_t in_length;
    ResizeAreaCachedInterpolation *interps;
    size_t offset;
    float scale;
    size_t mode = min(size_t(1), pos / out_height);
    in_length = pair_in_length[mode];
    interps = pair_interps[mode];
    scale = pair_scale[mode];
    offset = mode > 0 ? pos - out_height : pos;
    // Source span covered by output index `offset` is [transit_0, transit_1).
    float transit_0 = offset * scale;
    float transit_1 = (offset + 1) * scale;
    size_t v = floor(transit_0);
    interps[offset].start = v;
    interps[offset].start_scale = v < transit_0 ? (v + 1 > transit_1 ? scale : v + 1 - transit_0)
                                                : (v + 1 > transit_1 ? transit_1 - v : 1.0);
    v = ceil(transit_1);
    interps[offset].end = v;
    v = interps[offset].end - 1;
    interps[offset].end_minus_one_scale = v < transit_0 ? (v + 1 > transit_1 ? scale : v + 1 - transit_0)
                                                        : (v + 1 > transit_1 ? transit_1 - v : 1.0);
    interps[offset].needs_bounding = BoundResizeArea(interps[offset].start, in_length) != interps[offset].start ||
                                     BoundResizeArea(interps[offset].end, in_length) != (interps[offset].end - 1);
  }
  return;
}

template <typename T>
__global__ void PatchSum(const T *images, const ResizeAreaCachedInterpolation *x_interps,
                         const ResizeAreaCachedInterpolation *y_interps, float *output, int32_t batch_size,
                         const int32_t channels, const int32_t out_height, const int32_t out_width, const float scale,
                         const int32_t in_height, const int32_t in_width) {
#define BOUND_IF_NEEDED(x, y, NeedsBounding) (NeedsBounding ? BoundResizeArea(x, y) : (x))
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
       pos < batch_size * channels * out_height * out_width; pos += gridDim.x * blockDim.x) {
    size_t tem_pos = pos;
    size_t image_id = tem_pos / (channels * out_height * out_width);
    tem_pos = tem_pos % (channels * out_height * out_width);
    size_t y_id = tem_pos / (out_width * channels);
    tem_pos = tem_pos % (out_width * channels);
    size_t x_id = tem_pos / channels;
    size_t channel_id = tem_pos % channels;
    ResizeAreaCachedInterpolation x_interp = x_interps[x_id];
    ResizeAreaCachedInterpolation y_interp = y_interps[y_id];
    size_t start_offset, y_offset;
    start_offset = image_id * in_height * in_width * channels;
    float sum = 0;

    // First source row of the patch, weighted by y_interp.start_scale.
    y_offset = start_offset +
               BOUND_IF_NEEDED(y_interp.start, in_height, y_interp.needs_bounding) * in_width * channels;
    float scale_x = x_interp.start_scale;
    float sum_y = static_cast<float>(images[y_offset + channels * BOUND_IF_NEEDED(
                    x_interp.start, in_width, x_interp.needs_bounding) + channel_id]) * scale_x;
    if (x_interp.start + 1 != x_interp.end) {
      for (size_t x = x_interp.start + 1; x < x_interp.end - 1; ++x) {
        sum_y += static_cast<float>(images[y_offset + channels * BOUND_IF_NEEDED(
                   x, in_width, x_interp.needs_bounding) + channel_id]);
      }
      scale_x = x_interp.end_minus_one_scale;
      sum_y += static_cast<float>(images[y_offset + channels * BOUND_IF_NEEDED(
                 x_interp.end - 1, in_width, x_interp.needs_bounding) + channel_id]) * scale_x;
    }
    sum += sum_y * y_interp.start_scale;
    // Remaining rows of the patch: interior rows contribute with weight 1,
    // the last row is weighted by y_interp.end_minus_one_scale.
    if (y_interp.start + 1 != y_interp.end) {
      for (size_t y = y_interp.start + 1; y < y_interp.end - 1; ++y) {
        y_offset = start_offset + BOUND_IF_NEEDED(y, in_height, y_interp.needs_bounding) * in_width * channels;
        scale_x = x_interp.start_scale;
        sum_y = static_cast<float>(images[y_offset + channels * BOUND_IF_NEEDED(
                  x_interp.start, in_width, x_interp.needs_bounding) + channel_id]) * scale_x;
        if (x_interp.start + 1 != x_interp.end) {
          for (size_t x = x_interp.start + 1; x < x_interp.end - 1; ++x) {
            sum_y += static_cast<float>(images[y_offset + channels * BOUND_IF_NEEDED(
                       x, in_width, x_interp.needs_bounding) + channel_id]);
          }
          scale_x = x_interp.end_minus_one_scale;
          sum_y += static_cast<float>(images[y_offset + channels * BOUND_IF_NEEDED(
                     x_interp.end - 1, in_width, x_interp.needs_bounding) + channel_id]) * scale_x;
        }
        sum += sum_y;
      }
      y_offset = start_offset +
                 BOUND_IF_NEEDED(y_interp.end - 1, in_height, y_interp.needs_bounding) * in_width * channels;
      scale_x = x_interp.start_scale;
      sum_y = static_cast<float>(images[y_offset + channels *
                BOUND_IF_NEEDED(x_interp.start, in_width, x_interp.needs_bounding) + channel_id]) * scale_x;
      if (x_interp.start + 1 != x_interp.end) {
        for (size_t x = x_interp.start + 1; x < x_interp.end - 1; ++x) {
          sum_y += static_cast<float>(images[y_offset + channels *
                     BOUND_IF_NEEDED(x, in_width, x_interp.needs_bounding) + channel_id]);
        }
        scale_x = x_interp.end_minus_one_scale;
        sum_y += static_cast<float>(images[y_offset + channels * BOUND_IF_NEEDED(
                   x_interp.end - 1, in_width, x_interp.needs_bounding) + channel_id]) * scale_x;
      }
      sum += sum_y * y_interp.end_minus_one_scale;
    }
    // `scale` is 1 / (height_scale * width_scale), so the weighted sum becomes an area average.
    output[pos] = sum * scale;
  }
  return;
}

// float16 specialisation: identical logic, loads converted with __half2float.
template <>
__global__ void PatchSum(const half *images, const ResizeAreaCachedInterpolation *x_interps,
                         const ResizeAreaCachedInterpolation *y_interps, float *output, int32_t batch_size,
                         const int32_t channels, const int32_t out_height, const int32_t out_width, const float scale,
                         const int32_t in_height, const int32_t in_width) {
#define BOUND_IF_NEEDED(x, y, NeedsBounding) (NeedsBounding ? BoundResizeArea(x, y) : (x))
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
       pos < batch_size * channels * out_height * out_width; pos += gridDim.x * blockDim.x) {
    size_t tem_pos = pos;
    size_t image_id = tem_pos / (channels * out_height * out_width);
    tem_pos = tem_pos % (channels * out_height * out_width);
    size_t y_id = tem_pos / (out_width * channels);
    tem_pos = tem_pos % (out_width * channels);
    size_t x_id = tem_pos / channels;
    size_t channel_id = tem_pos % channels;
    ResizeAreaCachedInterpolation x_interp = x_interps[x_id];
    ResizeAreaCachedInterpolation y_interp = y_interps[y_id];
    size_t start_offset, y_offset;
    start_offset = image_id * in_height * in_width * channels;
    float sum = 0;
    y_offset = start_offset +
               BOUND_IF_NEEDED(y_interp.start, in_height, y_interp.needs_bounding) * in_width * channels;
    float scale_x = x_interp.start_scale;
    float sum_y = __half2float(images[y_offset + channels * BOUND_IF_NEEDED(
                    x_interp.start, in_width, x_interp.needs_bounding) + channel_id]) * scale_x;
    if (x_interp.start + 1 != x_interp.end) {
      for (size_t x = x_interp.start + 1; x < x_interp.end - 1; ++x) {
        sum_y += __half2float(images[y_offset + channels *
                   BOUND_IF_NEEDED(x, in_width, x_interp.needs_bounding) + channel_id]);
      }
      scale_x = x_interp.end_minus_one_scale;
      sum_y += __half2float(images[y_offset + channels * BOUND_IF_NEEDED(
                 x_interp.end - 1, in_width, x_interp.needs_bounding) + channel_id]) * scale_x;
    }
    sum += sum_y * y_interp.start_scale;
    if (y_interp.start + 1 != y_interp.end) {
      for (size_t y = y_interp.start + 1; y < y_interp.end - 1; ++y) {
        y_offset = start_offset + BOUND_IF_NEEDED(y, in_height, y_interp.needs_bounding) * in_width * channels;
        scale_x = x_interp.start_scale;
        sum_y = __half2float(images[y_offset + channels * BOUND_IF_NEEDED(
                  x_interp.start, in_width, x_interp.needs_bounding) + channel_id]) * scale_x;
        if (x_interp.start + 1 != x_interp.end) {
          for (size_t x = x_interp.start + 1; x < x_interp.end - 1; ++x) {
            sum_y += __half2float(images[y_offset +
                       channels * BOUND_IF_NEEDED(x, in_width, x_interp.needs_bounding) + channel_id]);
          }
          scale_x = x_interp.end_minus_one_scale;
          sum_y += __half2float(images[y_offset + channels * BOUND_IF_NEEDED(
                     x_interp.end - 1, in_width, x_interp.needs_bounding) + channel_id]) * scale_x;
        }
        sum += sum_y;
      }
      y_offset = start_offset +
                 BOUND_IF_NEEDED(y_interp.end - 1, in_height, y_interp.needs_bounding) * in_width * channels;
      scale_x = x_interp.start_scale;
      sum_y = __half2float(images[y_offset + channels * BOUND_IF_NEEDED(
                x_interp.start, in_width, x_interp.needs_bounding) + channel_id]) * scale_x;
      if (x_interp.start + 1 != x_interp.end) {
        for (size_t x = x_interp.start + 1; x < x_interp.end - 1; ++x) {
          sum_y += __half2float(images[y_offset + channels *
                     BOUND_IF_NEEDED(x, in_width, x_interp.needs_bounding) + channel_id]);
        }
        scale_x = x_interp.end_minus_one_scale;
        sum_y += __half2float(images[y_offset + channels * BOUND_IF_NEEDED(
                   x_interp.end - 1, in_width, x_interp.needs_bounding) + channel_id]) * scale_x;
      }
      sum += sum_y * y_interp.end_minus_one_scale;
    }
    output[pos] = sum * scale;
  }
  return;
}

float Scaling(size_t in_size, size_t out_size, bool align_corners) {
  return (align_corners && out_size > 1) ? (in_size - 1) / static_cast<float>(out_size - 1)
                                         : in_size / static_cast<float>(out_size);
}

template <typename T>
void CalResizeArea(const T *images, ResizeAreaCachedInterpolation *x_interps, ResizeAreaCachedInterpolation *y_interps,
                   float *output, int32_t batch_size, const int32_t channels, const int32_t out_height,
                   const int32_t out_width, const int32_t in_height, const int32_t in_width, bool align_corners,
                   const uint32_t &device_id, cudaStream_t cuda_stream) {
  float height_scale = Scaling(in_height, out_height, align_corners);
  float width_scale = Scaling(in_width, out_width, align_corners);
  // Normalisation factor that turns the accumulated patch sum into an area-weighted average.
  float scale = 1.0 / (height_scale * width_scale);
  ComputeInterpolation<<<CUDA_BLOCKS(device_id, out_height + out_width), CUDA_THREADS(device_id), 0, cuda_stream>>>(
    x_interps, y_interps, in_height, in_width, out_height, out_width, height_scale, width_scale, align_corners);
  PatchSum<<<CUDA_BLOCKS(device_id, batch_size * out_height * out_width * channels), CUDA_THREADS(device_id), 0,
             cuda_stream>>>(images, x_interps, y_interps, output, batch_size, channels, out_height, out_width, scale,
                            in_height, in_width);
  return;
}

template CUDA_LIB_EXPORT void CalResizeArea<int8_t>(const int8_t *images,
  ResizeAreaCachedInterpolation *x_interps, ResizeAreaCachedInterpolation *y_interps, float *output,
  int32_t batch_size, const int32_t channels, const int32_t out_height, const int32_t out_width,
  const int32_t in_height, const int32_t in_width, bool align_corners, const uint32_t &device_id,
  cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalResizeArea<uint8_t>(const uint8_t *images,
  ResizeAreaCachedInterpolation *x_interps, ResizeAreaCachedInterpolation *y_interps, float *output,
  int32_t batch_size, const int32_t channels, const int32_t out_height, const int32_t out_width,
  const int32_t in_height, const int32_t in_width, bool align_corners, const uint32_t &device_id,
  cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalResizeArea<int16_t>(const int16_t *images,
  ResizeAreaCachedInterpolation *x_interps, ResizeAreaCachedInterpolation *y_interps, float *output,
  int32_t batch_size, const int32_t channels, const int32_t out_height, const int32_t out_width,
  const int32_t in_height, const int32_t in_width, bool align_corners, const uint32_t &device_id,
  cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalResizeArea<uint16_t>(const uint16_t *images,
  ResizeAreaCachedInterpolation *x_interps, ResizeAreaCachedInterpolation *y_interps, float *output,
  int32_t batch_size, const int32_t channels, const int32_t out_height, const int32_t out_width,
  const int32_t in_height, const int32_t in_width, bool align_corners, const uint32_t &device_id,
  cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalResizeArea<int32_t>(const int32_t *images,
  ResizeAreaCachedInterpolation *x_interps, ResizeAreaCachedInterpolation *y_interps, float *output,
  int32_t batch_size, const int32_t channels, const int32_t out_height, const int32_t out_width,
  const int32_t in_height, const int32_t in_width, bool align_corners, const uint32_t &device_id,
  cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalResizeArea<int64_t>(const int64_t *images,
  ResizeAreaCachedInterpolation *x_interps, ResizeAreaCachedInterpolation *y_interps, float *output,
  int32_t batch_size, const int32_t channels, const int32_t out_height, const int32_t out_width,
  const int32_t in_height, const int32_t in_width, bool align_corners, const uint32_t &device_id,
  cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalResizeArea<half>(const half *images,
  ResizeAreaCachedInterpolation *x_interps, ResizeAreaCachedInterpolation *y_interps, float *output,
  int32_t batch_size, const int32_t channels, const int32_t out_height, const int32_t out_width,
  const int32_t in_height, const int32_t in_width, bool align_corners, const uint32_t &device_id,
  cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalResizeArea<float>(const float *images,
  ResizeAreaCachedInterpolation *x_interps, ResizeAreaCachedInterpolation *y_interps, float *output,
  int32_t batch_size, const int32_t channels, const int32_t out_height, const int32_t out_width,
  const int32_t in_height, const int32_t in_width, bool align_corners, const uint32_t &device_id,
  cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalResizeArea<double>(const double *images,
  ResizeAreaCachedInterpolation *x_interps, ResizeAreaCachedInterpolation *y_interps, float *output,
  int32_t batch_size, const int32_t channels, const int32_t out_height, const int32_t out_width,
  const int32_t in_height, const int32_t in_width, bool align_corners, const uint32_t &device_id,
  cudaStream_t cuda_stream);
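To make the cached-span arithmetic in ComputeInterpolation and Scaling easier to follow, here is a small NumPy re-derivation of the 1-D case. This is an illustrative sketch only, not the shipped kernel; the function and variable names mirror the CUDA code above:

import numpy as np

def scaling(in_size, out_size, align_corners):
    # Mirrors Scaling(): input/output coordinate ratio along one axis.
    if align_corners and out_size > 1:
        return (in_size - 1) / float(out_size - 1)
    return in_size / float(out_size)

def span_weight(v, transit_0, transit_1, scale):
    # Overlap of source pixel v with the span [transit_0, transit_1),
    # matching the ternary expressions for start_scale / end_minus_one_scale.
    return ((scale if v + 1 > transit_1 else v + 1 - transit_0)
            if v < transit_0
            else (transit_1 - v if v + 1 > transit_1 else 1.0))

def cached_interpolation(offset, scale):
    # Mirrors one ComputeInterpolation iteration for a single output index.
    transit_0 = offset * scale        # left edge of the source span
    transit_1 = (offset + 1) * scale  # right edge of the source span
    start = int(np.floor(transit_0))
    end = int(np.ceil(transit_1))
    return (start, end,
            span_weight(start, transit_0, transit_1, scale),
            span_weight(end - 1, transit_0, transit_1, scale))

# Example: shrink a length-4 axis to length 3 with align_corners=False.
s = scaling(4, 3, False)           # 4/3
print(cached_interpolation(1, s))  # (1, 3, 0.666..., 0.666...): partial first and last source pixels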
@ -0,0 +1,35 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_RESIZE_AREA_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_RESIZE_AREA_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
#include "include/cuda_fp16.h"
#include "mindapi/base/types.h"
struct ResizeAreaCachedInterpolation {
  size_t start;
  size_t end;
  float start_scale;
  float end_minus_one_scale;
  bool needs_bounding = true;
};
template <typename T>
CUDA_LIB_EXPORT void CalResizeArea(const T *images, ResizeAreaCachedInterpolation *x_interps,
                                   ResizeAreaCachedInterpolation *y_interps, float *output, int32_t batch_size,
                                   const int32_t channels, const int32_t out_height, const int32_t out_width,
                                   const int32_t in_height, const int32_t in_width, bool align_corners,
                                   const uint32_t &device_id, cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_RESIZE_AREA_IMPL_CUH_
@ -30,6 +30,7 @@ namespace ops {
namespace {
abstract::ShapePtr ResizeAreaInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  constexpr int64_t size_num = 2;
  constexpr int64_t indexid2 = 2;
  constexpr int64_t indexid3 = 3;
  constexpr int64_t image_shape_size = 4;
  constexpr int64_t size_shape_size = 1;
@ -73,6 +74,18 @@ abstract::ShapePtr ResizeAreaInferShape(const PrimitivePtr &primitive, const std
    }
    return std::make_shared<abstract::Shape>(output_shape);
  } else {
    auto prim_name = primitive->name();
    auto x_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 0);
    auto x_shape = x_shape_ptr->shape();
    if (x_shape_ptr->IsDynamic()) {
      auto x_min_shape = x_shape_ptr->min_shape();
      auto x_max_shape = x_shape_ptr->max_shape();
      x_min_shape[1] = 0;
      x_min_shape[indexid2] = 0;
      x_max_shape[1] = 1;
      x_max_shape[indexid2] = 1;
      return std::make_shared<abstract::Shape>(x_shape, x_min_shape, x_max_shape);
    }
    ShapeVector out_shape = {images_shape[0], abstract::Shape::SHP_ANY, abstract::Shape::SHP_ANY,
                             images_shape[indexid3]};
    ShapeVector shape_min = {images_shape[0], 0, 0, images_shape[indexid3]};
@ -84,7 +97,12 @@ TypePtr ResizeAreaInferType(const PrimitivePtr &primitive, const std::vector<Abs
  return kFloat32;
}
}  // namespace
MIND_API_BASE_IMPL(ResizeArea, PrimitiveC, BaseOperator);
MIND_API_OPERATOR_IMPL(ResizeArea, BaseOperator);
void ResizeArea::Init(const bool align_corners) { this->set_align_corners(align_corners); }
void ResizeArea::set_align_corners(const bool align_corners) {
  (void)this->AddAttr(kAlignCorners, api::MakeValue(align_corners));
}
bool ResizeArea::get_align_corners() const { return GetValue<bool>(GetAttr(kAlignCorners)); }
AbstractBasePtr ResizeAreaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
@ -32,6 +32,9 @@ class MIND_API ResizeArea : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(ResizeArea);
  ResizeArea() : BaseOperator(kNameResizeArea) { InitIOName({"images", "size"}, {"y"}); }
  void Init(const bool align_corners = false);
  void set_align_corners(const bool align_corners);
  bool get_align_corners() const;
};
abstract::AbstractBasePtr ResizeAreaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                          const std::vector<abstract::AbstractBasePtr> &input_args);
@ -866,7 +866,7 @@ class ResizeArea(Primitive):
        ValueError: The size is not positive.

    Supported Platforms:
        ``Ascend`` ``CPU``
        ``Ascend`` ``CPU`` ``GPU``

    Examples:
        >>> images = Tensor([[[[2], [4], [6], [8]], [[10], [12], [14], [16]]]], mindspore.float16)
@ -0,0 +1,96 @@
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.ops.operations.image_ops as ops


class NetResizeArea(nn.Cell):

    def __init__(self, align_corners=False):
        super(NetResizeArea, self).__init__()
        self.resize_area = ops.ResizeArea(align_corners=align_corners)

    def construct(self, images, size):
        return self.resize_area(images, size)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_resize_area_float16():
    """
    Feature: Input type of float16
    Description: Input type of [float16, int32].
    Expectation: success.
    """
    for mode in [context.PYNATIVE_MODE]:
        context.set_context(mode=mode, device_target="GPU")
        data_type = np.half
        images_array = np.array([[[[1.2, 120], [140, 2.1], [40, 0.12]],
                                  [[1.35, 1.2], [0.04, 5], [10, 4]]],
                                 [[[34, 10.2], [1.05, 12.1], [3, 0.06]],
                                  [[65, 23], [14, 4.3], [2.2, 4]]]]).astype(data_type)
        size_array = np.array([2, 2]).astype(np.int32)
        align_corners = True

        images_ms = Tensor(images_array)
        size_ms = Tensor(size_array)
        resize_area = NetResizeArea(align_corners=align_corners)
        output_ms = resize_area(images_ms, size_ms)
        expect = np.array([[[[7.0600098e+01, 6.1049805e+01], [4.0000000e+01, 1.1999512e-01]],
                            [[6.9480896e-01, 3.1000977e+00], [1.0000000e+01, 4.0000000e+00]]],
                           [[[1.7524902e+01, 1.1152344e+01], [3.0000000e+00, 5.9997559e-02]],
                            [[3.9500000e+01, 1.3650391e+01], [2.1992188e+00, 4.0000000e+00]]]]).astype(np.float32)
        assert np.allclose(output_ms.asnumpy(),
                           expect,
                           rtol=1e-4,
                           atol=1e-4,
                           equal_nan=False)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_resize_area_float32():
    """
    Feature: Input type of float32
    Description: Input type of [float32, int32].
    Expectation: success.
    """
    for mode in [context.PYNATIVE_MODE]:
        context.set_context(mode=mode, device_target="GPU")
        data_type = np.float32
        images_array = np.array([[[[1.2, 0.1, 120], [0.140, 2.5, 21]], [[135, 0.02, 2], [0.00102, 4.1, 3]]],
                                 [[[4, 0.01, 1.02], [21, 0.15, 11]], [[65, 2.1, 23], [22, 1.2, 4]]]]).astype(data_type)
        size_array = np.array([3, 3]).astype(np.int32)
        align_corners = False

        images_ms = Tensor(images_array)
        size_ms = Tensor(size_array)
        resize_area = NetResizeArea(align_corners=align_corners)
        output_ms = resize_area(images_ms, size_ms)
        expect = np.array([[[[1.2000000e+00, 1.0000000e-01, 1.1999999e+02],
                             [6.6999996e-01, 1.3000001e+00, 7.0499992e+01],
                             [1.3999999e-01, 2.4999995e+00, 2.0999996e+01]],
                            [[6.8099998e+01, 5.9999995e-02, 6.0999989e+01],
                             [3.4085251e+01, 1.6800001e+00, 3.6499989e+01],
                             [7.0509978e-02, 3.2999995e+00, 1.1999997e+01]],
                            [[1.3499998e+02, 1.9999996e-02, 1.9999996e+00],
                             [6.7500488e+01, 2.0599999e+00, 2.4999995e+00],
                             [1.0199997e-03, 4.0999990e+00, 2.9999993e+00]]],
                           [[[3.9999998e+00, 9.9999988e-03, 1.0200000e+00],
                             [1.2500000e+01, 8.0000013e-02, 6.0100002e+00],
                             [2.0999996e+01, 1.4999999e-01, 1.0999999e+01]],
                            [[3.4500004e+01, 1.0549999e+00, 1.2010000e+01],
                             [2.8000000e+01, 8.6499995e-01, 9.7550001e+00],
                             [2.1499998e+01, 6.7500001e-01, 7.4999986e+00]],
                            [[6.4999992e+01, 2.0999997e+00, 2.2999998e+01],
                             [4.3499992e+01, 1.6499997e+00, 1.3499997e+01],
                             [2.1999996e+01, 1.1999998e+00, 3.9999990e+00]]]]).astype(np.float32)
        assert np.allclose(output_ms.asnumpy(),
                           expect,
                           rtol=1e-4,
                           atol=1e-4,
                           equal_nan=False)
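For readers who want to sanity-check the hard-coded expected arrays, the values can be compared against TensorFlow's area resize, whose cached-interpolation structure the CUDA code above closely mirrors. This cross-check is an assumption (it requires a TensorFlow installation and is not part of the test suite):

import numpy as np
import tensorflow as tf  # assumed available purely for cross-checking

# Same input as test_resize_area_float32 above; the result should agree with `expect` to ~1e-4.
images = np.array([[[[1.2, 0.1, 120], [0.140, 2.5, 21]], [[135, 0.02, 2], [0.00102, 4.1, 3]]],
                   [[[4, 0.01, 1.02], [21, 0.15, 11]], [[65, 2.1, 23], [22, 1.2, 4]]]], dtype=np.float32)
reference = tf.compat.v1.image.resize_area(images, [3, 3], align_corners=False)
print(np.asarray(reference))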