From b1e87594276a77b2d9ff6ffc6338a397356107d6 Mon Sep 17 00:00:00 2001 From: ymy_forever <971180567@qq.com> Date: Wed, 20 Jul 2022 16:24:18 +0800 Subject: [PATCH] [feat] [assistant] [I4ZZV6] add new GPU operator TripletMarginLoss --- .../cuda_class/triplet_margin_loss_helper.h | 516 ++++++++++++++++++ .../cuda_ops/triplet_margin_loss_impl.cu | 508 +++++++++++++++++ .../cuda_ops/triplet_margin_loss_impl.cuh | 37 ++ .../nn/triplet_margin_loss_gpu_kernel.cc | 294 ++++++++++ .../nn/triplet_margin_loss_gpu_kernel.h | 61 +++ mindspore/core/ops/op_name.h | 1 + mindspore/core/ops/triplet_margin_loss.cc | 20 +- mindspore/core/ops/triplet_margin_loss.h | 10 + .../python/mindspore/ops/operations/nn_ops.py | 2 +- .../st/ops/gpu/test_triplet_margin_loss_op.py | 112 ++++ 10 files changed, 1558 insertions(+), 3 deletions(-) create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/triplet_margin_loss_helper.h create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/triplet_margin_loss_impl.cu create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/triplet_margin_loss_impl.cuh create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/nn/triplet_margin_loss_gpu_kernel.cc create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/nn/triplet_margin_loss_gpu_kernel.h create mode 100644 tests/st/ops/gpu/test_triplet_margin_loss_op.py diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/triplet_margin_loss_helper.h b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/triplet_margin_loss_helper.h new file mode 100644 index 00000000000..ff5d72722aa --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/triplet_margin_loss_helper.h @@ -0,0 +1,516 @@ +/** + * Copyright 2019-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_TRIPLET_MARGIN_LOSS_HELPER_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_TRIPLET_MARGIN_LOSS_HELPER_H_ +#include +#include +#include +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/triplet_margin_loss_impl.cuh" + +namespace mindspore { +namespace cukernel { +class TripletMarginLossAttr : public GpuKernelAttrBase { + public: + TripletMarginLossAttr() = default; + ~TripletMarginLossAttr() override = default; + int64_t p; + bool swap; + float eps; + std::string reduction; +}; + +template +class TripletMarginLossHelperGpuKernel : public GpuKernelHelperBase { + public: + explicit TripletMarginLossHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id) + : GpuKernelHelperBase(kernel_name, device_id) { + shape_size_ = 0; + need_broadcast_ = false; + is_null_input_ = false; + reduction_ = "mean"; + bound_ = 0; + } + + virtual ~TripletMarginLossHelperGpuKernel() = default; + int CalMemSize(const std::vector> &input_shapes, + const std::vector> &output_shapes) override { + constexpr size_t INPUT_NUM_1 = 3; + constexpr size_t OUTPUT_NUM = 1; + constexpr int64_t kzero = 0; + constexpr int64_t kone = 1; + constexpr int64_t ktwo = 2; + constexpr int64_t kthree = 3; + ResetResource(); + dst_shape_.clear(); + std::vector> input_shapes_1; + input_shapes_1.emplace_back(input_shapes[kzero]); + input_shapes_1.emplace_back(input_shapes[kone]); + input_shapes_1.emplace_back(input_shapes[ktwo]); + int inp_flag = + CalShapesSizeInBytes(input_shapes_1, INPUT_NUM_1, kernel_name_, "input_shapes_1", &input_size_list_); + if (inp_flag == -1) { + return inp_flag; + } + input_size_list_.emplace_back(sizeof(M)); + int out_flag = + CalShapesSizeInBytes(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_); + if (out_flag == -1) { + return out_flag; + } + is_null_input_ = (inp_flag == 1 || out_flag == 1); + anchor_shape_ = input_shapes[kzero]; + positive_shape_ = input_shapes[kone]; + negative_shape_ = input_shapes[ktwo]; + size_t dim_x = anchor_shape_.size(); + size_t dim_positive = positive_shape_.size(); + size_t dim_negative = negative_shape_.size(); + shape_size_ = std::max(std::max(dim_x, dim_positive), dim_negative); + std::reverse(anchor_shape_.begin(), anchor_shape_.end()); + std::reverse(positive_shape_.begin(), positive_shape_.end()); + std::reverse(negative_shape_.begin(), negative_shape_.end()); + if (dim_x < shape_size_) anchor_shape_.resize(shape_size_, kone); + if (dim_positive < shape_size_) positive_shape_.resize(shape_size_, kone); + if (dim_negative < shape_size_) negative_shape_.resize(shape_size_, kone); + std::reverse(anchor_shape_.begin(), anchor_shape_.end()); + std::reverse(positive_shape_.begin(), positive_shape_.end()); + std::reverse(negative_shape_.begin(), negative_shape_.end()); + if (anchor_shape_ != positive_shape_ || anchor_shape_ != negative_shape_ || positive_shape_ != negative_shape_) { + need_broadcast_ = true; + } + int64_t tem_shape_size = 0; + for (size_t i = 0; i < shape_size_; i++) { + tem_shape_size++; + dst_shape_.push_back((int64_t)std::max(std::max(anchor_shape_[i], positive_shape_[i]), negative_shape_[i])); + } + std::vector> workspace_shapes_sizet; + std::vector> workspace_shapes_S; + constexpr size_t WORKSPACE_SIZET_NUM = 1; + constexpr size_t WORKSPACE_S_NUM = 1; + std::vector shape_shape; + tem_shape_size *= 4; // store 4 shapes + 
shape_shape.push_back(tem_shape_size); + workspace_shapes_sizet.emplace_back(shape_shape); + swap_ = attr_ptr_->swap; + std::vector tem_output_shape(dst_shape_); + tem_output_shape.erase(tem_output_shape.begin() + 1); + if (swap_) { + tem_output_shape.insert(tem_output_shape.begin(), kthree); + } else { + tem_output_shape.insert(tem_output_shape.begin(), ktwo); + } + workspace_shapes_S.emplace_back(tem_output_shape); + int work_flag = CalShapesSizeInBytes(workspace_shapes_sizet, WORKSPACE_SIZET_NUM, kernel_name_, + "workspace_shapes", &work_size_list_); + if (work_flag == -1) { + return work_flag; + } + work_flag = + CalShapesSizeInBytes(workspace_shapes_S, WORKSPACE_S_NUM, kernel_name_, "workspace_shapes", &work_size_list_); + if (work_flag == -1) { + return work_flag; + } + size_t workspace_boundlist = kthree * sizeof(size_t); + work_size_list_.emplace_back(workspace_boundlist); + if (need_broadcast_) { + std::vector> workspace_shapes_T; + constexpr size_t WORKSPACE_T_NUM = 1; + workspace_shapes_T.emplace_back(dst_shape_); + workspace_shapes_T[0].insert(workspace_shapes_T[0].begin(), kthree); + work_flag = CalShapesSizeInBytes(workspace_shapes_T, WORKSPACE_T_NUM, kernel_name_, "workspace_shapes", + &work_size_list_); + if (work_flag == -1) { + return work_flag; + } + } + return CheckKernelParam(); + } + + int Process(const std::vector &input_ptrs, const std::vector &output_ptrs, + const std::vector &work_ptrs, void *cuda_stream) override { + const int64_t kzero = 0; + const int64_t kone = 1; + const int64_t ktwo = 2; + const int64_t kthree = 3; + const int64_t kfour = 4; + if (is_null_input_) { + return 0; + } + bound_list_[kzero] = ChooseBound(anchor_shape_[kone], positive_shape_[kone], dst_shape_[kone]); + bound_list_[kone] = ChooseBound(anchor_shape_[kone], negative_shape_[kone], dst_shape_[kone]); + bound_list_[ktwo] = ChooseBound(positive_shape_[kone], negative_shape_[kone], dst_shape_[kone]); + bound_ = dst_shape_[kone]; + + size_t outer_size = dst_shape_[kzero]; + size_t inner_size = 1; + for (size_t i = 2; i < shape_size_; i++) { + inner_size *= dst_shape_[i]; + } + T *anchor_ptr = nullptr; + T *positive_ptr = nullptr; + T *negative_ptr = nullptr; + M *margin_ptr = nullptr; + S *output_ptr = nullptr; + + int flag = GetDeviceAddress(input_ptrs, kzero, kernel_name_, &anchor_ptr); + if (flag != 0) { + return flag; + } + flag = GetDeviceAddress(input_ptrs, kone, kernel_name_, &positive_ptr); + if (flag != 0) { + return flag; + } + flag = GetDeviceAddress(input_ptrs, ktwo, kernel_name_, &negative_ptr); + if (flag != 0) { + return flag; + } + flag = GetDeviceAddress(input_ptrs, kthree, kernel_name_, &margin_ptr); + if (flag != 0) { + return flag; + } + flag = GetDeviceAddress(output_ptrs, kzero, kernel_name_, &output_ptr); + if (flag != 0) { + return flag; + } + std::vector input_shapes; + input_shapes.insert(input_shapes.end(), anchor_shape_.begin(), anchor_shape_.end()); + input_shapes.insert(input_shapes.end(), positive_shape_.begin(), positive_shape_.end()); + input_shapes.insert(input_shapes.end(), negative_shape_.begin(), negative_shape_.end()); + input_shapes.insert(input_shapes.end(), dst_shape_.begin(), dst_shape_.end()); + int64_t *anchor_shape_ptr = nullptr, *dst_shape_ptr = nullptr; + S *tem_output_ptr = nullptr; + size_t *bound_list_ptr = nullptr; + T *anchor_broadcast_ptr = anchor_ptr, *positive_broadcast_ptr = positive_ptr, + *negative_broadcast_ptr = positive_ptr; + flag = GetDeviceAddress(work_ptrs, kzero, kernel_name_, &anchor_shape_ptr); + if (flag != 0) { + return 
flag; + } + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(anchor_shape_ptr, &input_shapes[kzero], shape_size_ * sizeof(int64_t) * kfour, + cudaMemcpyHostToDevice, reinterpret_cast(cuda_stream)), + "cudaMemcpyAsync workspace failed"); + + dst_shape_ptr = anchor_shape_ptr + kthree * shape_size_; + + flag = GetDeviceAddress(work_ptrs, kone, kernel_name_, &tem_output_ptr); + if (flag != 0) { + return flag; + } + + flag = GetDeviceAddress(work_ptrs, ktwo, kernel_name_, &bound_list_ptr); + if (flag != 0) { + return flag; + } + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(bound_list_ptr, &bound_list_[kzero], sizeof(size_t) * kthree, cudaMemcpyHostToDevice, + reinterpret_cast(cuda_stream)), + "cudaMemcpyAsync workspace failed"); + + if (need_broadcast_) { + flag = GetDeviceAddress(work_ptrs, kthree, kernel_name_, &anchor_broadcast_ptr); + if (flag != 0) { + return flag; + } + positive_broadcast_ptr = anchor_broadcast_ptr + bound_ * outer_size * inner_size; + negative_broadcast_ptr = positive_broadcast_ptr + bound_ * outer_size * inner_size; + } + CalTripletMarginLoss(anchor_ptr, positive_ptr, negative_ptr, anchor_broadcast_ptr, positive_broadcast_ptr, + negative_broadcast_ptr, output_ptr, tem_output_ptr, anchor_shape_ptr, dst_shape_ptr, + outer_size, inner_size, bound_list_ptr, bound_, shape_size_, margin_ptr, attr_ptr_->p, + attr_ptr_->eps, reduction_, swap_, need_broadcast_, device_id_, + reinterpret_cast(cuda_stream)); + return 0; + } + + void SetKernelParam(const GpuKernelAttrBasePtr &kernel_attr) override { + attr_ptr_ = std::dynamic_pointer_cast(kernel_attr); + } + + protected: + int CheckKernelParam() override { + std::string reduction_list = "[mean,none,sum]"; + reduction_ = attr_ptr_->reduction; + if (reduction_ != "mean" && reduction_ != "none" && reduction_ != "sum") { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'reduciton' should be in " << reduction_list + << "but got:" << reduction_; + return -1; + } + return 0; + } + + size_t ChooseBound(size_t src_bound_first, size_t src_bound_second, size_t dst_bound) { + if (src_bound_first == 1 && src_bound_second == 1 && dst_bound != 1) { + return 1; + } + return dst_bound; + } + + private: + std::shared_ptr attr_ptr_; + std::vector anchor_shape_, positive_shape_, negative_shape_, dst_shape_; + size_t shape_size_; + size_t bound_list_[3]; + size_t bound_; + bool need_broadcast_; + bool swap_; + bool is_null_input_; + std::string reduction_; +}; + +// half +template +class TripletMarginLossHelperGpuKernel : public GpuKernelHelperBase { + public: + explicit TripletMarginLossHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id) + : GpuKernelHelperBase(kernel_name, device_id) { + shape_size_ = 0; + need_broadcast_ = false; + is_null_input_ = false; + reduction_ = "mean"; + bound_ = 0; + } + + virtual ~TripletMarginLossHelperGpuKernel() = default; + int CalMemSize(const std::vector> &input_shapes, + const std::vector> &output_shapes) override { + constexpr size_t INPUT_NUM_1 = 3; + constexpr size_t OUTPUT_NUM = 1; + constexpr int64_t kzero = 0; + constexpr int64_t kone = 1; + constexpr int64_t ktwo = 2; + constexpr int64_t kthree = 3; + ResetResource(); + dst_shape_.clear(); + std::vector> input_shapes_1; + input_shapes_1.emplace_back(input_shapes[kzero]); + input_shapes_1.emplace_back(input_shapes[kone]); + input_shapes_1.emplace_back(input_shapes[ktwo]); + int inp_flag = + CalShapesSizeInBytes(input_shapes_1, INPUT_NUM_1, kernel_name_, "input_shapes_1", &input_size_list_); + if (inp_flag == -1) { + 
return inp_flag; + } + input_size_list_.emplace_back(ktwo * sizeof(M)); + int out_flag = + CalShapesSizeInBytes(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_); + if (out_flag == -1) { + return out_flag; + } + is_null_input_ = (inp_flag == 1 || out_flag == 1); + anchor_shape_ = input_shapes[kzero]; + positive_shape_ = input_shapes[kone]; + negative_shape_ = input_shapes[ktwo]; + size_t dim_x = anchor_shape_.size(); + size_t dim_positive = positive_shape_.size(); + size_t dim_negative = negative_shape_.size(); + shape_size_ = std::max(std::max(dim_x, dim_positive), dim_negative); + std::reverse(anchor_shape_.begin(), anchor_shape_.end()); + std::reverse(positive_shape_.begin(), positive_shape_.end()); + std::reverse(negative_shape_.begin(), negative_shape_.end()); + if (dim_x < shape_size_) anchor_shape_.resize(shape_size_, kone); + if (dim_positive < shape_size_) positive_shape_.resize(shape_size_, kone); + if (dim_negative < shape_size_) negative_shape_.resize(shape_size_, kone); + std::reverse(anchor_shape_.begin(), anchor_shape_.end()); + std::reverse(positive_shape_.begin(), positive_shape_.end()); + std::reverse(negative_shape_.begin(), negative_shape_.end()); + if (anchor_shape_ != positive_shape_ || anchor_shape_ != negative_shape_ || positive_shape_ != negative_shape_) { + need_broadcast_ = true; + } + int64_t tem_shape_size = 0; + for (size_t i = 0; i < shape_size_; i++) { + tem_shape_size++; + dst_shape_.push_back((int64_t)std::max(std::max(anchor_shape_[i], positive_shape_[i]), negative_shape_[i])); + } + std::vector> workspace_shapes_sizet; + std::vector> workspace_shapes_S; + constexpr size_t WORKSPACE_SIZET_NUM = 1; + constexpr size_t WORKSPACE_S_NUM = 1; + std::vector shape_shape; + tem_shape_size *= 4; // store 4 shapes + shape_shape.push_back(tem_shape_size); + workspace_shapes_sizet.emplace_back(shape_shape); + swap_ = attr_ptr_->swap; + std::vector tem_output_shape(dst_shape_); + tem_output_shape.erase(tem_output_shape.begin() + 1); + if (swap_) { + tem_output_shape.insert(tem_output_shape.begin(), kthree); + } else { + tem_output_shape.insert(tem_output_shape.begin(), ktwo); + } + workspace_shapes_S.emplace_back(tem_output_shape); + int work_flag = CalShapesSizeInBytes(workspace_shapes_sizet, WORKSPACE_SIZET_NUM, kernel_name_, + "workspace_shapes", &work_size_list_); + if (work_flag == -1) { + return work_flag; + } + work_flag = CalShapesSizeInBytes(workspace_shapes_S, WORKSPACE_S_NUM, kernel_name_, "workspace_shapes", + &work_size_list_); + if (work_flag == -1) { + return work_flag; + } + size_t workspace_boundlist = kthree * sizeof(size_t); + work_size_list_.emplace_back(workspace_boundlist); + if (need_broadcast_) { + std::vector> workspace_shapes_T; + constexpr size_t WORKSPACE_T_NUM = 1; + workspace_shapes_T.emplace_back(dst_shape_); + workspace_shapes_T[0].insert(workspace_shapes_T[0].begin(), kthree); + work_flag = CalShapesSizeInBytes(workspace_shapes_T, WORKSPACE_T_NUM, kernel_name_, "workspace_shapes", + &work_size_list_); + if (work_flag == -1) { + return work_flag; + } + } + return CheckKernelParam(); + } + + int Process(const std::vector &input_ptrs, const std::vector &output_ptrs, + const std::vector &work_ptrs, void *cuda_stream) override { + const int64_t kzero = 0; + const int64_t kone = 1; + const int64_t ktwo = 2; + const int64_t kthree = 3; + const int64_t kfour = 4; + if (is_null_input_) { + return 0; + } + bound_list_[kzero] = ChooseBound(anchor_shape_[kone], positive_shape_[kone], dst_shape_[kone]); + 
bound_list_[kone] = ChooseBound(anchor_shape_[kone], negative_shape_[kone], dst_shape_[kone]); + bound_list_[ktwo] = ChooseBound(positive_shape_[kone], negative_shape_[kone], dst_shape_[kone]); + bound_ = dst_shape_[kone]; + + size_t outer_size = dst_shape_[kzero]; + size_t inner_size = 1; + for (size_t i = 2; i < shape_size_; i++) { + inner_size *= dst_shape_[i]; + } + half *anchor_ptr = nullptr; + half *positive_ptr = nullptr; + half *negative_ptr = nullptr; + M *margin_ptr = nullptr; + half *output_ptr = nullptr; + + int flag = GetDeviceAddress(input_ptrs, kzero, kernel_name_, &anchor_ptr); + if (flag != 0) { + return flag; + } + flag = GetDeviceAddress(input_ptrs, kone, kernel_name_, &positive_ptr); + if (flag != 0) { + return flag; + } + flag = GetDeviceAddress(input_ptrs, ktwo, kernel_name_, &negative_ptr); + if (flag != 0) { + return flag; + } + flag = GetDeviceAddress(input_ptrs, kthree, kernel_name_, &margin_ptr); + if (flag != 0) { + return flag; + } + flag = GetDeviceAddress(output_ptrs, kzero, kernel_name_, &output_ptr); + if (flag != 0) { + return flag; + } + std::vector input_shapes; + input_shapes.insert(input_shapes.end(), anchor_shape_.begin(), anchor_shape_.end()); + input_shapes.insert(input_shapes.end(), positive_shape_.begin(), positive_shape_.end()); + input_shapes.insert(input_shapes.end(), negative_shape_.begin(), negative_shape_.end()); + input_shapes.insert(input_shapes.end(), dst_shape_.begin(), dst_shape_.end()); + int64_t *anchor_shape_ptr = nullptr, *dst_shape_ptr = nullptr; + float *tem_output_ptr = nullptr; + size_t *bound_list_ptr = nullptr; + half *anchor_broadcast_ptr = anchor_ptr, *positive_broadcast_ptr = positive_ptr, + *negative_broadcast_ptr = positive_ptr; + flag = GetDeviceAddress(work_ptrs, kzero, kernel_name_, &anchor_shape_ptr); + if (flag != 0) { + return flag; + } + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(anchor_shape_ptr, &input_shapes[kzero], shape_size_ * sizeof(int64_t) * kfour, + cudaMemcpyHostToDevice, reinterpret_cast(cuda_stream)), + "cudaMemcpyAsync workspace failed"); + + dst_shape_ptr = anchor_shape_ptr + kthree * shape_size_; + + flag = GetDeviceAddress(work_ptrs, kone, kernel_name_, &tem_output_ptr); + if (flag != 0) { + return flag; + } + + flag = GetDeviceAddress(work_ptrs, ktwo, kernel_name_, &bound_list_ptr); + if (flag != 0) { + return flag; + } + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(bound_list_ptr, &bound_list_[kzero], sizeof(size_t) * kthree, cudaMemcpyHostToDevice, + reinterpret_cast(cuda_stream)), + "cudaMemcpyAsync workspace failed"); + + if (need_broadcast_) { + flag = GetDeviceAddress(work_ptrs, kthree, kernel_name_, &anchor_broadcast_ptr); + if (flag != 0) { + return flag; + } + positive_broadcast_ptr = anchor_broadcast_ptr + bound_ * outer_size * inner_size; + negative_broadcast_ptr = positive_broadcast_ptr + bound_ * outer_size * inner_size; + } + CalTripletMarginLoss(anchor_ptr, positive_ptr, negative_ptr, anchor_broadcast_ptr, positive_broadcast_ptr, + negative_broadcast_ptr, output_ptr, tem_output_ptr, anchor_shape_ptr, dst_shape_ptr, + outer_size, inner_size, bound_list_ptr, bound_, shape_size_, margin_ptr, attr_ptr_->p, + attr_ptr_->eps, reduction_, swap_, need_broadcast_, device_id_, + reinterpret_cast(cuda_stream)); + return 0; + } + + void SetKernelParam(const GpuKernelAttrBasePtr &kernel_attr) override { + attr_ptr_ = std::dynamic_pointer_cast(kernel_attr); + } + + protected: + int CheckKernelParam() override { + std::string reduction_list = "[mean,none,sum]"; + reduction_ = 
attr_ptr_->reduction; + if (reduction_ != "mean" && reduction_ != "none" && reduction_ != "sum") { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'reduciton' should be in " << reduction_list + << "but got:" << reduction_; + return -1; + } + return 0; + } + + size_t ChooseBound(size_t src_bound_first, size_t src_bound_second, size_t dst_bound) { + if (src_bound_first == 1 && src_bound_second == 1 && dst_bound != 1) { + return 1; + } + return dst_bound; + } + + private: + std::shared_ptr attr_ptr_; + std::vector anchor_shape_, positive_shape_, negative_shape_, dst_shape_; + size_t shape_size_; + size_t bound_list_[3]; + size_t bound_; + bool need_broadcast_; + bool swap_; + bool is_null_input_; + std::string reduction_; +}; +} // namespace cukernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_TRIPLET_MARGIN_LOSS_HELPER_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/triplet_margin_loss_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/triplet_margin_loss_impl.cu new file mode 100644 index 00000000000..b058cebd12f --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/triplet_margin_loss_impl.cu @@ -0,0 +1,508 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "triplet_margin_loss_impl.cuh" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" + +__device__ __forceinline__ int64_t Index(const int64_t &index, const int64_t &dim) { return dim == 1 ? 
0 : index; } + + +template +__global__ void FillAndBroadcast(const int64_t size, const size_t shape_size, + const int64_t *tensor_shapes, const int64_t *dst_shape, + const T *anchor, const T *positive, const T *negative, + T *anchor_broadcast) { + const T *pair_tensor[3] = {anchor, positive, negative}; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < 3*size; pos += blockDim.x * gridDim.x) { + const size_t mode = pos/size; + const int64_t *src_shape = tensor_shapes + shape_size * mode; + size_t tmp_pos = pos % size; + size_t pos_size = size / dst_shape[0]; + size_t dst_index_array[8]; + dst_index_array[0] = tmp_pos / pos_size; + for (size_t i = 1; i < shape_size; i++) { + tmp_pos -= dst_index_array[i - 1] * pos_size; + pos_size = pos_size / dst_shape[i]; + dst_index_array[i] = tmp_pos / pos_size; + } + size_t src_size = 1; + for (size_t i = 0; i < shape_size; i++) { + src_size *= src_shape[i]; + } + size_t src_pos = 0; + for (size_t i = 0; i < shape_size; i++) { + src_size /= src_shape[i]; + size_t length_by_index = Index(dst_index_array[i], src_shape[i]) * src_size; + src_pos += length_by_index; + } + (anchor_broadcast + mode * size)[pos % size] = pair_tensor[mode][src_pos]; + } + return; +} + +template +__global__ void PairwiseDistance(const T *anchor, const T *positive, const T *negative, + const size_t *bound_list, const size_t bound, const size_t outer_size, + const size_t inner_size, S *tem_output, const size_t n, const int64_t p, + const float eps) { + const T *pair_tensor[3] = {anchor, positive, negative}; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; + pos < n * outer_size * inner_size; pos += gridDim.x * blockDim.x) { + size_t mode = pos / (outer_size * inner_size); + size_t idx = pos % (outer_size * inner_size); + S res = 0; + size_t x = idx / inner_size % outer_size; + size_t y = idx % inner_size; + for (int i = 0; i < bound_list[mode]; i++) { + size_t input_offset = x * bound * inner_size + i * inner_size + y; + S base = + abs(static_cast(pair_tensor[mode / 2][input_offset] - pair_tensor[(mode + 3) / 2][input_offset]) + eps); + S tem = pow(base, static_cast(p)); + res += tem; + } + tem_output[pos] = pow(res, static_cast(1.0 / p)); + } + return; +} + +template +__global__ void PairwiseDistancePzero(const size_t *bound_list, const size_t output_size, + S *tem_output, const size_t n) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < n * output_size; pos += gridDim.x * blockDim.x) { + size_t mode = pos / output_size; + tem_output[pos] = static_cast(bound_list[mode]); + } + return; +} + + +template +__global__ void SwapTrue(S *tem_output, const size_t output_size) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_size; pos += gridDim.x * blockDim.x) { + tem_output[pos + output_size] = tem_output[pos + output_size] > tem_output[pos + 2 * output_size] ? 
+ tem_output[pos + 2 * output_size] : tem_output[pos + output_size]; + } + return; +} + +template +__global__ void MaxReduction(S *tem_output, S *output, const size_t output_size, const M *margin) { + S lower_bound = 0; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_size; pos += gridDim.x * blockDim.x) { + output[pos] = max(static_cast(margin[0]) + tem_output[pos] - tem_output[pos + output_size], lower_bound); + } + return; +} + + +template +__global__ void AddTile(S *tmp_loss, size_t index) { + tmp_loss[0] += tmp_loss[index]; +} + +template +__global__ void PartialSum(S *tmp_loss, size_t stride) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < stride; i += blockDim.x * gridDim.x) { + tmp_loss[i] += tmp_loss[i + stride]; + } +} + +template +__global__ void ReductionDivde(S *output, S *tem_output, const size_t k) { + output[0] = tem_output[0] / k; +} + +// half +template <> +__global__ void PairwiseDistance(const half *anchor, const half *positive, const half *negative, + const size_t *bound_list, const size_t bound, const size_t outer_size, + const size_t inner_size, float *tem_output, const size_t n, const int64_t p, + const float eps) { + const half *pair_tensor[3] = {anchor, positive, negative}; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; + pos < n * outer_size * inner_size; pos += gridDim.x * blockDim.x) { + size_t mode = pos / (outer_size * inner_size); + size_t idx = pos % (outer_size * inner_size); + float res = 0; + size_t x = idx / inner_size % outer_size; + size_t y = idx % inner_size; + for (int i = 0; i < bound_list[mode]; i++) { + size_t input_offset = x * bound * inner_size + i * inner_size + y; + float base = abs(__half2float(pair_tensor[mode / 2][input_offset]) - + __half2float(pair_tensor[(mode+3) / 2][input_offset]) + eps); + float tem = pow(base, p); + res += tem; + } + tem_output[pos] = pow(res, 1.0 / p); + } + return; +} + + +// half +template +__global__ void MaxReduction(float *tem_output, half *output, const size_t output_size, const M *margin) { + float lower_bound = 0; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_size; pos += gridDim.x * blockDim.x) { + output[pos] = __float2half(max(margin[0] + tem_output[pos] - tem_output[pos + output_size], lower_bound)); + } + return; +} + + +// half +__global__ void ReductionDivde(half *output, float *tem_output, const size_t k) { + output[0] = __float2half((tem_output[0] / k)); +} + +// Complex +template +__global__ void PairwiseDistance(const Complex *anchor, const Complex *positive, + const Complex *negative, const size_t *bound_list, + const size_t bound, const size_t outer_size, const size_t inner_size, + S *tem_output, const size_t n, const int64_t p, const float eps) { + const Complex *pair_tensor[3] = {anchor, positive, negative}; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; + pos < n * outer_size * inner_size; pos += gridDim.x * blockDim.x) { + size_t mode = pos / (outer_size * inner_size); + size_t idx = pos % (outer_size * inner_size); + S res = 0; + size_t x = idx / inner_size % outer_size; + size_t y = idx % inner_size; + for (int i = 0; i < bound_list[mode]; i++) { + size_t input_offset = x * bound * inner_size + i * inner_size + y; + Complex base_complex = + pair_tensor[mode / 2][input_offset] - pair_tensor[(mode + 3) / 2][input_offset] + static_cast(eps); + S base = sqrt((base_complex.real() * base_complex.real() + base_complex.imag() * base_complex.imag())); + S tem = pow(base, static_cast(p)); + res += tem; + } + 
tem_output[pos] = pow(res, 1.0 / p); + } + return; +} + +template +void CalTripletMarginLoss(const T *anchor, const T *positive, const T *negative, T *anchor_broadcast, + T *positive_broadcast, T *negative_broadcast, S *output, H *tem_output, + const int64_t *tensor_shapes, const int64_t *dst_shape, const size_t outer_size, + const size_t inner_size, const size_t *bound_list, const size_t bound, + const size_t shape_size, M *margin, const int64_t p, const float eps, + const std::string reduction, const bool swap, const bool need_broadcast, + const uint32_t &device_id, cudaStream_t cuda_stream) { + const int64_t size = outer_size * inner_size * bound; + size_t n; + if (swap) + n = 3; + else + n = 2; + const size_t output_size = outer_size * inner_size; + if (p == 0) { + PairwiseDistancePzero<<>> + (bound_list, output_size, tem_output, n); + } else if (need_broadcast) { + FillAndBroadcast<<>>( + size, shape_size, tensor_shapes, dst_shape, anchor, positive, negative, + anchor_broadcast); + PairwiseDistance<<>>( + anchor_broadcast, positive_broadcast, negative_broadcast, bound_list, bound, outer_size, + inner_size, tem_output, n, p, eps); + } else { + PairwiseDistance<<>>( + anchor, positive, negative, bound_list, bound, outer_size, inner_size, tem_output, n, p, eps); + } + + if (swap) { + SwapTrue<<>>(tem_output, + output_size); + } + if (reduction == "none") { + MaxReduction<<>>(tem_output, + output, output_size, margin); + } else { + MaxReduction<<>>(tem_output, + tem_output, output_size, margin); + if (output_size % 2 == 1 && output_size != 1) { + AddTile<<<1, 1, 0, cuda_stream>>>(tem_output, output_size - 1); + } + for (size_t stride = output_size / 2; stride > 0; stride >>= 1) { + PartialSum<<>>(tem_output, stride); + if (stride > 2 && stride % 2 == 1) { + AddTile<<<1, 1, 0, cuda_stream>>>(tem_output, stride - 1); + } + } + if (reduction == "mean") { + ReductionDivde<<<1, 1, 0, cuda_stream>>>(output, tem_output, output_size); + } else { + ReductionDivde<<<1, 1, 0, cuda_stream>>>(output, tem_output, 1); + } + } + + return; +} + + +template CUDA_LIB_EXPORT void CalTripletMarginLoss( + const int8_t *anchor, const int8_t *positive, const int8_t *negative, + int8_t *anchor_broadcast, int8_t *positive_broadcast, int8_t *negative_broadcast, + float *output, float *tem_output, const int64_t *tensor_shapes, + const int64_t *dst_shape, const size_t outer_size, const size_t inner_size, + const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin, + const int64_t p, const float eps, const std::string reduction, + const bool swap, const bool need_broadcast, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalTripletMarginLoss( + const int16_t *anchor, const int16_t *positive, const int16_t *negative, + int16_t *anchor_broadcast, int16_t *positive_broadcast, int16_t *negative_broadcast, + float *output, float *tem_output, const int64_t *tensor_shapes, + const int64_t *dst_shape, const size_t outer_size, const size_t inner_size, + const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin, + const int64_t p, const float eps, const std::string reduction, + const bool swap, const bool need_broadcast, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalTripletMarginLoss( + const int32_t *anchor, const int32_t *positive, const int32_t *negative, + int32_t *anchor_broadcast, int32_t *positive_broadcast, int32_t *negative_broadcast, + float *output, float *tem_output, const int64_t 
*tensor_shapes, + const int64_t *dst_shape, const size_t outer_size, const size_t inner_size, + const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin, + const int64_t p, const float eps, const std::string reduction, + const bool swap, const bool need_broadcast, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalTripletMarginLoss( + const int64_t *anchor, const int64_t *positive, const int64_t *negative, + int64_t *anchor_broadcast, int64_t *positive_broadcast, int64_t *negative_broadcast, + float *output, float *tem_output, const int64_t *tensor_shapes, + const int64_t *dst_shape, const size_t outer_size, const size_t inner_size, + const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin, + const int64_t p, const float eps, const std::string reduction, + const bool swap, const bool need_broadcast, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalTripletMarginLoss( + const uint8_t *anchor, const uint8_t *positive, const uint8_t *negative, + uint8_t *anchor_broadcast, uint8_t *positive_broadcast, uint8_t *negative_broadcast, + float *output, float *tem_output, const int64_t *tensor_shapes, + const int64_t *dst_shape, const size_t outer_size, const size_t inner_size, + const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin, + const int64_t p, const float eps, const std::string reduction, + const bool swap, const bool need_broadcast, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalTripletMarginLoss( + const uint16_t *anchor, const uint16_t *positive, const uint16_t *negative, + uint16_t *anchor_broadcast, uint16_t *positive_broadcast, uint16_t *negative_broadcast, + float *output, float *tem_output, const int64_t *tensor_shapes, + const int64_t *dst_shape, const size_t outer_size, const size_t inner_size, + const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin, + const int64_t p, const float eps, const std::string reduction, + const bool swap, const bool need_broadcast, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalTripletMarginLoss( + const uint32_t *anchor, const uint32_t *positive, const uint32_t *negative, + uint32_t *anchor_broadcast, uint32_t *positive_broadcast, uint32_t *negative_broadcast, + float *output, float *tem_output, const int64_t *tensor_shapes, + const int64_t *dst_shape, const size_t outer_size, const size_t inner_size, + const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin, + const int64_t p, const float eps, const std::string reduction, + const bool swap, const bool need_broadcast, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalTripletMarginLoss( + const uint64_t *anchor, const uint64_t *positive, const uint64_t *negative, + uint64_t *anchor_broadcast, uint64_t *positive_broadcast, uint64_t *negative_broadcast, + float *output, float *tem_output, const int64_t *tensor_shapes, + const int64_t *dst_shape, const size_t outer_size, const size_t inner_size, + const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin, + const int64_t p, const float eps, const std::string reduction, + const bool swap, const bool need_broadcast, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalTripletMarginLoss( + const double *anchor, const double *positive, const double 
*negative, + double *anchor_broadcast, double *positive_broadcast, double *negative_broadcast, + double *output, double *tem_output, const int64_t *tensor_shapes, + const int64_t *dst_shape, const size_t outer_size, const size_t inner_size, + const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin, + const int64_t p, const float eps, const std::string reduction, + const bool swap, const bool need_broadcast, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalTripletMarginLoss( + const float *anchor, const float *positive, const float *negative, + float *anchor_broadcast, float *positive_broadcast, float *negative_broadcast, + float *output, float *tem_output, const int64_t *tensor_shapes, + const int64_t *dst_shape, const size_t outer_size, const size_t inner_size, + const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin, + const int64_t p, const float eps, const std::string reduction, + const bool swap, const bool need_broadcast, const uint32_t &device_id, + cudaStream_t cuda_stream); + +template CUDA_LIB_EXPORT void CalTripletMarginLoss( + const half *anchor, const half *positive, const half *negative, + half *anchor_broadcast, half *positive_broadcast, half *negative_broadcast, + half *output, float *tem_output, const int64_t *tensor_shapes, + const int64_t *dst_shape, const size_t outer_size, const size_t inner_size, + const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin, + const int64_t p, const float eps, const std::string reduction, + const bool swap, const bool need_broadcast, const uint32_t &device_id, + cudaStream_t cuda_stream); + +template CUDA_LIB_EXPORT void +CalTripletMarginLoss, float, float, float>( + const Complex *anchor, const Complex *positive, const Complex *negative, + Complex *anchor_broadcast, Complex *positive_broadcast, Complex *negative_broadcast, + float *output, float *tem_output, + const int64_t *tensor_shapes, const int64_t *dst_shape, + const size_t outer_size, const size_t inner_size, const size_t *bound_list, const size_t bound, + const size_t shape_size, float *margin, const int64_t p, + const float eps, const std::string reduction, const bool swap, + const bool need_broadcast, const uint32_t &device_id, + cudaStream_t cuda_stream); + +template CUDA_LIB_EXPORT void +CalTripletMarginLoss, float, double, double>( + const Complex *anchor, const Complex *positive, + const Complex *negative, + Complex *anchor_broadcast, Complex *positive_broadcast, Complex *negative_broadcast, + double *output, double *tem_output, + const int64_t *tensor_shapes, const int64_t *dst_shape, + const size_t outer_size, const size_t inner_size, const size_t *bound_list, const size_t bound, + const size_t shape_size, float *margin, const int64_t p, + const float eps, const std::string reduction, const bool swap, + const bool need_broadcast, const uint32_t &device_id, + cudaStream_t cuda_stream); + +template CUDA_LIB_EXPORT void CalTripletMarginLoss( + const int8_t *anchor, const int8_t *positive, const int8_t *negative, + int8_t *anchor_broadcast, int8_t *positive_broadcast, int8_t *negative_broadcast, + float *output, float *tem_output, const int64_t *tensor_shapes, + const int64_t *dst_shape, const size_t outer_size, const size_t inner_size, + const size_t *bound_list, const size_t bound, const size_t shape_size, double *margin, + const int64_t p, const float eps, const std::string reduction, + const bool swap, const bool need_broadcast, const uint32_t 
&device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalTripletMarginLoss( + const int16_t *anchor, const int16_t *positive, const int16_t *negative, + int16_t *anchor_broadcast, int16_t *positive_broadcast, int16_t *negative_broadcast, + float *output, float *tem_output, const int64_t *tensor_shapes, + const int64_t *dst_shape, const size_t outer_size, const size_t inner_size, + const size_t *bound_list, const size_t bound, const size_t shape_size, double *margin, + const int64_t p, const float eps, const std::string reduction, + const bool swap, const bool need_broadcast, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalTripletMarginLoss( + const int32_t *anchor, const int32_t *positive, const int32_t *negative, + int32_t *anchor_broadcast, int32_t *positive_broadcast, int32_t *negative_broadcast, + float *output, float *tem_output, const int64_t *tensor_shapes, + const int64_t *dst_shape, const size_t outer_size, const size_t inner_size, + const size_t *bound_list, const size_t bound, const size_t shape_size, double *margin, + const int64_t p, const float eps, const std::string reduction, + const bool swap, const bool need_broadcast, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalTripletMarginLoss( + const int64_t *anchor, const int64_t *positive, const int64_t *negative, + int64_t *anchor_broadcast, int64_t *positive_broadcast, int64_t *negative_broadcast, + float *output, float *tem_output, const int64_t *tensor_shapes, + const int64_t *dst_shape, const size_t outer_size, const size_t inner_size, + const size_t *bound_list, const size_t bound, const size_t shape_size, double *margin, + const int64_t p, const float eps, const std::string reduction, + const bool swap, const bool need_broadcast, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalTripletMarginLoss( + const uint8_t *anchor, const uint8_t *positive, const uint8_t *negative, + uint8_t *anchor_broadcast, uint8_t *positive_broadcast, uint8_t *negative_broadcast, + float *output, float *tem_output, const int64_t *tensor_shapes, + const int64_t *dst_shape, const size_t outer_size, const size_t inner_size, + const size_t *bound_list, const size_t bound, const size_t shape_size, double *margin, + const int64_t p, const float eps, const std::string reduction, + const bool swap, const bool need_broadcast, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalTripletMarginLoss( + const uint16_t *anchor, const uint16_t *positive, const uint16_t *negative, + uint16_t *anchor_broadcast, uint16_t *positive_broadcast, uint16_t *negative_broadcast, + float *output, float *tem_output, const int64_t *tensor_shapes, + const int64_t *dst_shape, const size_t outer_size, const size_t inner_size, + const size_t *bound_list, const size_t bound, const size_t shape_size, double *margin, + const int64_t p, const float eps, const std::string reduction, + const bool swap, const bool need_broadcast, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalTripletMarginLoss( + const uint32_t *anchor, const uint32_t *positive, const uint32_t *negative, + uint32_t *anchor_broadcast, uint32_t *positive_broadcast, uint32_t *negative_broadcast, + float *output, float *tem_output, const int64_t *tensor_shapes, + const int64_t *dst_shape, const size_t outer_size, const size_t inner_size, + const size_t *bound_list, const size_t bound, const size_t 
shape_size, double *margin, + const int64_t p, const float eps, const std::string reduction, + const bool swap, const bool need_broadcast, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalTripletMarginLoss( + const uint64_t *anchor, const uint64_t *positive, const uint64_t *negative, + uint64_t *anchor_broadcast, uint64_t *positive_broadcast, uint64_t *negative_broadcast, + float *output, float *tem_output, const int64_t *tensor_shapes, + const int64_t *dst_shape, const size_t outer_size, const size_t inner_size, + const size_t *bound_list, const size_t bound, const size_t shape_size, double *margin, + const int64_t p, const float eps, const std::string reduction, + const bool swap, const bool need_broadcast, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalTripletMarginLoss( + const double *anchor, const double *positive, const double *negative, + double *anchor_broadcast, double *positive_broadcast, double *negative_broadcast, + double *output, double *tem_output, const int64_t *tensor_shapes, + const int64_t *dst_shape, const size_t outer_size, const size_t inner_size, + const size_t *bound_list, const size_t bound, const size_t shape_size, double *margin, + const int64_t p, const float eps, const std::string reduction, + const bool swap, const bool need_broadcast, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalTripletMarginLoss( + const float *anchor, const float *positive, const float *negative, + float *anchor_broadcast, float *positive_broadcast, float *negative_broadcast, + float *output, float *tem_output, const int64_t *tensor_shapes, + const int64_t *dst_shape, const size_t outer_size, const size_t inner_size, + const size_t *bound_list, const size_t bound, const size_t shape_size, double *margin, + const int64_t p, const float eps, const std::string reduction, + const bool swap, const bool need_broadcast, const uint32_t &device_id, + cudaStream_t cuda_stream); + +template CUDA_LIB_EXPORT void CalTripletMarginLoss( + const half *anchor, const half *positive, const half *negative, + half *anchor_broadcast, half *positive_broadcast, half *negative_broadcast, + half *output, float *tem_output, const int64_t *tensor_shapes, + const int64_t *dst_shape, const size_t outer_size, const size_t inner_size, + const size_t *bound_list, const size_t bound, const size_t shape_size, double *margin, + const int64_t p, const float eps, const std::string reduction, + const bool swap, const bool need_broadcast, const uint32_t &device_id, + cudaStream_t cuda_stream); + +template CUDA_LIB_EXPORT void +CalTripletMarginLoss, double, float, float>( + const Complex *anchor, const Complex *positive, + const Complex *negative, Complex *anchor_broadcast, Complex *positive_broadcast, + Complex *negative_broadcast, float *output, float *tem_output, + const int64_t *tensor_shapes, const int64_t *dst_shape, + const size_t outer_size, const size_t inner_size, const size_t *bound_list, const size_t bound, + const size_t shape_size, double *margin, const int64_t p, + const float eps, const std::string reduction, const bool swap, + const bool need_broadcast, const uint32_t &device_id, + cudaStream_t cuda_stream); + +template CUDA_LIB_EXPORT void +CalTripletMarginLoss, double, double, double>( + const Complex *anchor, const Complex *positive, + const Complex *negative, Complex *anchor_broadcast, Complex *positive_broadcast, + Complex *negative_broadcast, double *output, double *tem_output, + 
const int64_t *tensor_shapes, const int64_t *dst_shape, + const size_t outer_size, const size_t inner_size, const size_t *bound_list, const size_t bound, + const size_t shape_size, double *margin, const int64_t p, + const float eps, const std::string reduction, const bool swap, + const bool need_broadcast, const uint32_t &device_id, + cudaStream_t cuda_stream); + diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/triplet_margin_loss_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/triplet_margin_loss_impl.cuh new file mode 100644 index 00000000000..fa83fc53031 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/triplet_margin_loss_impl.cuh @@ -0,0 +1,37 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRIPLET_MARGIN_LOSS_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRIPLET_MARGIN_LOSS_IMPL_CUH_ +#include +#include +#include "include/cuda_fp16.h" +#include "mindapi/base/types.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" +template using Complex = mindspore::utils::Complex; + +template +CUDA_LIB_EXPORT void CalTripletMarginLoss(const T *anchor, const T *positive, const T *negative, T *anchor_broadcast, + T *positive_broadcast, T *negative_broadcast, S *output, H *tem_output, const int64_t *tensor_shapes, + const int64_t *dst_shape, const size_t outer_size, const size_t inner_size, const size_t *bound_list, + const size_t bound, const size_t shape_size, M *margin, const int64_t p, const float eps, + const std::string reduction, const bool swap, const bool need_broadcast, const uint32_t &device_id, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRIPLET_MARGIN_LOSS_IMPL_CUH_ + + diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/triplet_margin_loss_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/triplet_margin_loss_gpu_kernel.cc new file mode 100644 index 00000000000..0f0e4c96179 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/triplet_margin_loss_gpu_kernel.cc @@ -0,0 +1,294 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "plugin/device/gpu/kernel/nn/triplet_margin_loss_gpu_kernel.h" +#include +namespace mindspore { +namespace kernel { +namespace { +template +std::unique_ptr CreateTripletMarginLossKernelPtr(const std::string &kernel_name, + const uint32_t &device_id) { + return std::make_unique>(kernel_name, device_id); +} + +using TripletMarginLossPtrCreatorFunc = + std::function(const std::string &, const uint32_t &)>; + +const std::vector> kernel_attr = { + {KernelAttr() + .AddInputAttr(kNumberTypeComplex64) + .AddInputAttr(kNumberTypeComplex64) + .AddInputAttr(kNumberTypeComplex64) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + CreateTripletMarginLossKernelPtr, float, float>}, + {KernelAttr() + .AddInputAttr(kNumberTypeComplex128) + .AddInputAttr(kNumberTypeComplex128) + .AddInputAttr(kNumberTypeComplex128) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat64), + CreateTripletMarginLossKernelPtr, float, double>}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat64), + CreateTripletMarginLossKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + CreateTripletMarginLossKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat16), + CreateTripletMarginLossKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + CreateTripletMarginLossKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + CreateTripletMarginLossKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + CreateTripletMarginLossKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + CreateTripletMarginLossKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + CreateTripletMarginLossKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + CreateTripletMarginLossKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + CreateTripletMarginLossKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + CreateTripletMarginLossKernelPtr}, + {KernelAttr() + 
.AddInputAttr(kNumberTypeComplex64) + .AddInputAttr(kNumberTypeComplex64) + .AddInputAttr(kNumberTypeComplex64) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat32), + CreateTripletMarginLossKernelPtr, double, float>}, + {KernelAttr() + .AddInputAttr(kNumberTypeComplex128) + .AddInputAttr(kNumberTypeComplex128) + .AddInputAttr(kNumberTypeComplex128) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64), + CreateTripletMarginLossKernelPtr, double, double>}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64), + CreateTripletMarginLossKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat32), + CreateTripletMarginLossKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat16), + CreateTripletMarginLossKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat32), + CreateTripletMarginLossKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat32), + CreateTripletMarginLossKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat32), + CreateTripletMarginLossKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat32), + CreateTripletMarginLossKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat32), + CreateTripletMarginLossKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat32), + CreateTripletMarginLossKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat32), + CreateTripletMarginLossKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat32), + CreateTripletMarginLossKernelPtr}}; +} // namespace + +bool TripletMarginLossGpuKernelMod::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + std::vector input_ptrs = ConvertPtrs(inputs); + std::vector work_ptrs = ConvertPtrs(workspace); + std::vector output_ptrs = ConvertPtrs(outputs); + if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) { + return false; + } + return true; +} + +bool TripletMarginLossGpuKernelMod::Init(const BaseOperatorPtr 
+                                         const std::vector<KernelTensorPtr> &inputs,
+                                         const std::vector<KernelTensorPtr> &outputs) {
+  auto kernel_ptr = std::dynamic_pointer_cast<ops::TripletMarginLoss>(base_operator);
+  kernel_name_ = kernel_ptr->name();
+  auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
+  auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport());
+  if (!is_match) {
+    return false;
+  }
+  attr_ptr_->p = kernel_ptr->get_p();
+  attr_ptr_->swap = kernel_ptr->get_swap();
+  attr_ptr_->eps = kernel_ptr->get_eps();
+  attr_ptr_->reduction = kernel_ptr->get_reduction();
+  helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_));
+  helper_ptr_->SetKernelParam(attr_ptr_);
+
+  Resize(base_operator, inputs, outputs);
+  return true;
+}
+
+int TripletMarginLossGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
+                                          const std::vector<KernelTensorPtr> &inputs,
+                                          const std::vector<KernelTensorPtr> &outputs,
+                                          const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
+  for (const auto &input : inputs) {
+    auto input_shape = input->GetShapeVector();
+    if (!IsValidShape(input_shape)) {
+      return KRET_UNKNOWN_SHAPE;
+    }
+  }
+  constexpr int64_t kzero = 0;
+  constexpr int64_t kone = 1;
+  constexpr int64_t ktwo = 2;
+  constexpr int64_t kthree = 3;
+  std::vector<std::vector<int64_t>> input_shapes;
+  std::vector<std::vector<int64_t>> output_shapes;
+
+  std::vector<int64_t> inp_shape1 = inputs[kzero]->GetShapeVector();
+  std::vector<int64_t> inp_shape2 = inputs[kone]->GetShapeVector();
+  std::vector<int64_t> inp_shape3 = inputs[ktwo]->GetShapeVector();
+  std::vector<int64_t> inp_shape4 = inputs[kthree]->GetShapeVector();
+  std::vector<int64_t> out_shape = outputs[kzero]->GetShapeVector();
+  input_shapes.emplace_back(inp_shape1);
+  input_shapes.emplace_back(inp_shape2);
+  input_shapes.emplace_back(inp_shape3);
+  input_shapes.emplace_back(inp_shape4);
+  output_shapes.emplace_back(out_shape);
+  if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) {
+    return KRET_RESIZE_FAILED;
+  }
+  input_size_list_ = helper_ptr_->GetInputSizeList();
+  output_size_list_ = helper_ptr_->GetOutputSizeList();
+  workspace_size_list_ = helper_ptr_->GetWorkSizeList();
+  return KRET_OK;
+}
+
+std::vector<KernelAttr> TripletMarginLossGpuKernelMod::GetOpSupport() {
+  std::vector<KernelAttr> support_list;
+  (void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list),
+                       [](const std::pair<KernelAttr, TripletMarginLossPtrCreatorFunc> &item) { return item.first; });
+  return support_list;
+}
+
+MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, TripletMarginLoss, TripletMarginLossGpuKernelMod);
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/triplet_margin_loss_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/triplet_margin_loss_gpu_kernel.h
new file mode 100644
index 00000000000..061489f8bbc
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/triplet_margin_loss_gpu_kernel.h
@@ -0,0 +1,61 @@
+/**
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_TRIPLET_MARGIN_LOSS_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_TRIPLET_MARGIN_LOSS_GPU_KERNEL_H_
+
+#include <algorithm>
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "mindspore/core/ops/triplet_margin_loss.h"
+#include "plugin/device/gpu/kernel/gpu_kernel.h"
+#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/triplet_margin_loss_helper.h"
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
+#include "kernel/kernel.h"
+namespace mindspore {
+namespace kernel {
+template <typename T>
+using Complex = mindspore::utils::Complex<T>;
+
+class TripletMarginLossGpuKernelMod : public NativeGpuKernelMod {
+ public:
+  TripletMarginLossGpuKernelMod() { attr_ptr_ = std::make_shared<cukernel::TripletMarginLossAttr>(); }
+  ~TripletMarginLossGpuKernelMod() override = default;
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
+
+  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+            const std::vector<KernelTensorPtr> &outputs) override;
+
+  int Resize(
+    const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+    const std::vector<KernelTensorPtr> &outputs,
+    const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
+  std::vector<KernelAttr> GetOpSupport() override;
+
+ private:
+  std::unique_ptr<cukernel::GpuKernelHelperBase> helper_ptr_{nullptr};
+  std::shared_ptr<cukernel::TripletMarginLossAttr> attr_ptr_{nullptr};
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_TRIPLET_MARGIN_LOSS_GPU_KERNEL_H_
diff --git a/mindspore/core/ops/op_name.h b/mindspore/core/ops/op_name.h
index dc6df2a31da..9d21671f3cb 100644
--- a/mindspore/core/ops/op_name.h
+++ b/mindspore/core/ops/op_name.h
@@ -270,6 +270,7 @@ constexpr auto kSideEffectIO = "side_effect_io";
 constexpr auto kDeviceType = "device_type";
 constexpr auto kExclusive = "exclusive";
 constexpr auto kReverse = "reverse";
+constexpr auto kSwap = "swap";
 constexpr auto kSplitStride = "split_stride";
 constexpr auto kExtendTop = "extend_top";
 constexpr auto kExtendBottom = "extend_bottom";
diff --git a/mindspore/core/ops/triplet_margin_loss.cc b/mindspore/core/ops/triplet_margin_loss.cc
index da9394f05fa..b6e807ce1cb 100644
--- a/mindspore/core/ops/triplet_margin_loss.cc
+++ b/mindspore/core/ops/triplet_margin_loss.cc
@@ -77,7 +77,7 @@ TypePtr TripletMarginLossInferType(const PrimitivePtr &primitive, const std::vec
   auto op_name = primitive->name();
   const std::set<TypePtr> valid_types = {kComplex64, kComplex128, kFloat64, kFloat32, kFloat16, kInt16, kInt32,
                                          kInt64,     kInt8,       kUInt16,  kUInt32,  kUInt64,  kUInt8};
-  const std::set<TypePtr> valid_types2 = {kFloat32};
+  const std::set<TypePtr> valid_types2 = {kFloat32, kFloat64};
   std::map<std::string, TypePtr> types;
   types.emplace("x", input_args[kInputIndex0]->BuildType());
   types.emplace("positive", input_args[kInputIndex1]->BuildType());
@@ -87,8 +87,14 @@ TypePtr TripletMarginLossInferType(const PrimitivePtr &primitive, const std::vec
   (void)CheckAndConvertUtils::CheckTensorTypeValid("margin", margin, valid_types2, op_name);
   auto x_type = input_args[kInputIndex0]->BuildType();
   TypePtr output;
-  if (x_type == kFloat16) {
+  if (x_type->isa<TensorType>()) {
+    auto tensor_type = x_type->cast<TensorTypePtr>();
+    x_type = tensor_type->element();
+  }
+  if (IsIdentidityOrSubclass(x_type, kFloat16)) {
     output = kFloat16;
+  } else if (IsIdentidityOrSubclass(x_type, kFloat64) || IsIdentidityOrSubclass(x_type, kComplex128)) {
+    output = kFloat64;
   } else {
     output = kFloat32;
   }
@@ -97,6 +103,16 @@ TypePtr TripletMarginLossInferType(const PrimitivePtr &primitive, const std::vec
 }  // namespace
 MIND_API_OPERATOR_IMPL(TripletMarginLoss, BaseOperator);
+void TripletMarginLoss::set_p(const int64_t p) { (void)this->AddAttr(kP, api::MakeValue(p)); }
+void TripletMarginLoss::set_eps(const float eps) { (void)this->AddAttr(kEps, api::MakeValue(eps)); }
+void TripletMarginLoss::set_swap(const bool swap) { (void)this->AddAttr(kSwap, api::MakeValue(swap)); }
+void TripletMarginLoss::set_reduction(const std::string &reduction) {
+  (void)this->AddAttr(kReduction, api::MakeValue(reduction));
+}
+int64_t TripletMarginLoss::get_p() const { return GetValue<int64_t>(GetAttr(kP)); }
+float TripletMarginLoss::get_eps() const { return GetValue<float>(GetAttr(kEps)); }
+bool TripletMarginLoss::get_swap() const { return GetValue<bool>(GetAttr(kSwap)); }
+std::string TripletMarginLoss::get_reduction() const { return GetValue<std::string>(GetAttr(kReduction)); }
 AbstractBasePtr TripletMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
diff --git a/mindspore/core/ops/triplet_margin_loss.h b/mindspore/core/ops/triplet_margin_loss.h
index 14bde96830f..0513dc5d3ba 100644
--- a/mindspore/core/ops/triplet_margin_loss.h
+++ b/mindspore/core/ops/triplet_margin_loss.h
@@ -34,6 +34,16 @@ class MIND_API TripletMarginLoss : public BaseOperator {
   TripletMarginLoss() : BaseOperator(kNameTripletMarginLoss) {
     InitIOName({"x", "positive", "negative", "margin"}, {"y"});
   }
+  void Init(const int64_t p = 2, const float eps = 1e-6, const bool swap = false,
+            const std::string &reduction = "mean");
+  void set_p(const int64_t p);
+  void set_eps(const float eps);
+  void set_swap(const bool swap);
+  void set_reduction(const std::string &reduction);
+  int64_t get_p() const;
+  float get_eps() const;
+  bool get_swap() const;
+  std::string get_reduction() const;
 };
 abstract::AbstractBasePtr TripletMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/python/mindspore/ops/operations/nn_ops.py b/mindspore/python/mindspore/ops/operations/nn_ops.py
index 9df5bc8c5ec..8162ab4dd41 100644
--- a/mindspore/python/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/python/mindspore/ops/operations/nn_ops.py
@@ -10058,7 +10058,7 @@ class TripletMarginLoss(Primitive):
         ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.

Supported Platforms: - ``Ascend`` ``CPU`` + ``Ascend`` ``CPU`` ``GPU`` Examples: >>> loss = ops.TripletMarginLoss() diff --git a/tests/st/ops/gpu/test_triplet_margin_loss_op.py b/tests/st/ops/gpu/test_triplet_margin_loss_op.py new file mode 100644 index 00000000000..16a150d2481 --- /dev/null +++ b/tests/st/ops/gpu/test_triplet_margin_loss_op.py @@ -0,0 +1,112 @@ +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +import mindspore.ops.operations.nn_ops as ops +import torch + + +class NetTripletMarginLoss(nn.Cell): + + def __init__(self, p=2, swap=False, eps=1e-6, reduction="mean"): + super(NetTripletMarginLoss, self).__init__() + self.triplet_margin_loss = ops.TripletMarginLoss( + p=p, swap=swap, eps=eps, reduction=reduction) + + def construct(self, anchor, positive, negative, margin): + return self.triplet_margin_loss(anchor, positive, negative, margin) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_triplet_margin_loss_float64(): + """ + Feature: Input type of float64 + Description: Input type of [float64, float64, float64, float64]. + Expectation: success. + """ + for mode in [context.PYNATIVE_MODE, context.GRAPH_MODE]: + context.set_context(mode=mode, device_target="GPU") + data_type = np.float64 + anchor_array = np.array([[1.3, 20.5, 5.6]]).astype(data_type) + positive_array = np.array([[2., 10., 1.], + [6., 7., 10.], + [13., 4., 1.], + [0.33, -4, -1.5]]).astype(data_type) + negative_array = np.array([[2., 21., 6.], + [68., 9., 10.], + [131., 25., 16.], + [0.31, -0.14, -16.]]).astype(data_type) + margin = np.float32(2.0) + p = 0 + swap = True + reduction = "none" + eps = 1e-5 + + anchor = Tensor(anchor_array) + positive = Tensor(positive_array) + negative = Tensor(negative_array) + mind_margin = Tensor(margin) + triplet_margin_loss = NetTripletMarginLoss(p=p, swap=swap, reduction=reduction, eps=eps) + output_ms = triplet_margin_loss(anchor, positive, negative, mind_margin) + print(output_ms) + torch_anchor = torch.tensor(anchor_array) + torch_positive = torch.tensor(positive_array) + torch_negative = torch.tensor(negative_array) + torch_loss = torch.nn.TripletMarginLoss(margin=margin, p=p, swap=swap, reduction=reduction, eps=eps) + expect = torch_loss(torch_anchor, torch_positive, torch_negative) + assert np.allclose(output_ms.asnumpy(), + expect.numpy(), + rtol=1e-5, + atol=1e-5, + equal_nan=False) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_triplet_margin_loss_float32(): + """ + Feature: Input type of float32 + Description: Input type of [float32, float32, float32, float32]. + Expectation: success. 
+ """ + for mode in [context.GRAPH_MODE, context.PYNATIVE_MODE]: + context.set_context(mode=mode, device_target="GPU") + data_type = np.float32 + anchor_array = np.array([[1.3, 20.5, 5.6], + [3.5, 4.8, 7.2], + [0.2, 0.01, 1], + [4, 4.1, 20]]).astype(data_type) + positive_array = np.array([[2., 10., 1.], + [6., 7., 10.], + [13., 4., 1.], + [0.33, -4, -1.5]]).astype(data_type) + negative_array = np.array([[2., 21., 6.], + [68., 9., 10.], + [131., 25., 16.], + [0.31, -0.14, -16.]]).astype(data_type) + margin = 2.0 + p = 1 + swap = False + reduction = "none" + eps = 1e-6 + + anchor = Tensor(anchor_array) + positive = Tensor(positive_array) + negative = Tensor(negative_array) + mind_margin = Tensor(margin) + triplet_margin_loss = NetTripletMarginLoss(p=p, swap=swap, reduction=reduction, eps=eps) + output_ms = triplet_margin_loss(anchor, positive, negative, mind_margin) + torch_anchor = torch.tensor(anchor_array) + torch_positive = torch.tensor(positive_array) + torch_negative = torch.tensor(negative_array) + torch_loss = torch.nn.TripletMarginLoss(margin=margin, p=p, swap=swap, reduction=reduction, eps=eps) + expect = torch_loss(torch_anchor, torch_positive, torch_negative) + assert np.allclose(output_ms.asnumpy(), + expect.numpy(), + rtol=1e-4, + atol=1e-4, + equal_nan=False)