!34880 [assistant][ops][I4ZZV6] New GPU operator implementation, include TripletMarginLoss

Merge pull request !34880 from 黎冠新/TripletMarginLoss2
This commit is contained in:
i-robot 2022-07-21 01:54:26 +00:00 committed by Gitee
commit 0ab8fa7dab
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
10 changed files with 1558 additions and 3 deletions

View File

@ -0,0 +1,516 @@
/**
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_TRIPLET_MARGIN_LOSS_HELPER_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_TRIPLET_MARGIN_LOSS_HELPER_H_
#include <memory>
#include <string>
#include <vector>
#include <algorithm>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/triplet_margin_loss_impl.cuh"
namespace mindspore {
namespace cukernel {
class TripletMarginLossAttr : public GpuKernelAttrBase {
public:
TripletMarginLossAttr() = default;
~TripletMarginLossAttr() override = default;
int64_t p;
bool swap;
float eps;
std::string reduction;
};
template <typename T, typename M, typename S>
class TripletMarginLossHelperGpuKernel : public GpuKernelHelperBase {
public:
explicit TripletMarginLossHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id)
: GpuKernelHelperBase(kernel_name, device_id) {
shape_size_ = 0;
need_broadcast_ = false;
is_null_input_ = false;
reduction_ = "mean";
bound_ = 0;
}
virtual ~TripletMarginLossHelperGpuKernel() = default;
int CalMemSize(const std::vector<std::vector<int64_t>> &input_shapes,
const std::vector<std::vector<int64_t>> &output_shapes) override {
constexpr size_t INPUT_NUM_1 = 3;
constexpr size_t OUTPUT_NUM = 1;
constexpr int64_t kzero = 0;
constexpr int64_t kone = 1;
constexpr int64_t ktwo = 2;
constexpr int64_t kthree = 3;
ResetResource();
dst_shape_.clear();
std::vector<std::vector<int64_t>> input_shapes_1;
input_shapes_1.emplace_back(input_shapes[kzero]);
input_shapes_1.emplace_back(input_shapes[kone]);
input_shapes_1.emplace_back(input_shapes[ktwo]);
int inp_flag =
CalShapesSizeInBytes<T>(input_shapes_1, INPUT_NUM_1, kernel_name_, "input_shapes_1", &input_size_list_);
if (inp_flag == -1) {
return inp_flag;
}
input_size_list_.emplace_back(sizeof(M));
int out_flag =
CalShapesSizeInBytes<S>(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_);
if (out_flag == -1) {
return out_flag;
}
is_null_input_ = (inp_flag == 1 || out_flag == 1);
anchor_shape_ = input_shapes[kzero];
positive_shape_ = input_shapes[kone];
negative_shape_ = input_shapes[ktwo];
size_t dim_x = anchor_shape_.size();
size_t dim_positive = positive_shape_.size();
size_t dim_negative = negative_shape_.size();
shape_size_ = std::max(std::max(dim_x, dim_positive), dim_negative);
std::reverse(anchor_shape_.begin(), anchor_shape_.end());
std::reverse(positive_shape_.begin(), positive_shape_.end());
std::reverse(negative_shape_.begin(), negative_shape_.end());
if (dim_x < shape_size_) anchor_shape_.resize(shape_size_, kone);
if (dim_positive < shape_size_) positive_shape_.resize(shape_size_, kone);
if (dim_negative < shape_size_) negative_shape_.resize(shape_size_, kone);
std::reverse(anchor_shape_.begin(), anchor_shape_.end());
std::reverse(positive_shape_.begin(), positive_shape_.end());
std::reverse(negative_shape_.begin(), negative_shape_.end());
if (anchor_shape_ != positive_shape_ || anchor_shape_ != negative_shape_ || positive_shape_ != negative_shape_) {
need_broadcast_ = true;
}
int64_t tem_shape_size = 0;
for (size_t i = 0; i < shape_size_; i++) {
tem_shape_size++;
dst_shape_.push_back((int64_t)std::max(std::max(anchor_shape_[i], positive_shape_[i]), negative_shape_[i]));
}
std::vector<std::vector<int64_t>> workspace_shapes_sizet;
std::vector<std::vector<int64_t>> workspace_shapes_S;
constexpr size_t WORKSPACE_SIZET_NUM = 1;
constexpr size_t WORKSPACE_S_NUM = 1;
std::vector<int64_t> shape_shape;
tem_shape_size *= 4; // store 4 shapes
shape_shape.push_back(tem_shape_size);
workspace_shapes_sizet.emplace_back(shape_shape);
swap_ = attr_ptr_->swap;
std::vector<int64_t> tem_output_shape(dst_shape_);
tem_output_shape.erase(tem_output_shape.begin() + 1);
if (swap_) {
tem_output_shape.insert(tem_output_shape.begin(), kthree);
} else {
tem_output_shape.insert(tem_output_shape.begin(), ktwo);
}
workspace_shapes_S.emplace_back(tem_output_shape);
int work_flag = CalShapesSizeInBytes<int64_t>(workspace_shapes_sizet, WORKSPACE_SIZET_NUM, kernel_name_,
"workspace_shapes", &work_size_list_);
if (work_flag == -1) {
return work_flag;
}
work_flag =
CalShapesSizeInBytes<S>(workspace_shapes_S, WORKSPACE_S_NUM, kernel_name_, "workspace_shapes", &work_size_list_);
if (work_flag == -1) {
return work_flag;
}
size_t workspace_boundlist = kthree * sizeof(size_t);
work_size_list_.emplace_back(workspace_boundlist);
if (need_broadcast_) {
std::vector<std::vector<int64_t>> workspace_shapes_T;
constexpr size_t WORKSPACE_T_NUM = 1;
workspace_shapes_T.emplace_back(dst_shape_);
workspace_shapes_T[0].insert(workspace_shapes_T[0].begin(), kthree);
work_flag = CalShapesSizeInBytes<T>(workspace_shapes_T, WORKSPACE_T_NUM, kernel_name_, "workspace_shapes",
&work_size_list_);
if (work_flag == -1) {
return work_flag;
}
}
return CheckKernelParam();
}
int Process(const std::vector<void *> &input_ptrs, const std::vector<void *> &output_ptrs,
const std::vector<void *> &work_ptrs, void *cuda_stream) override {
const int64_t kzero = 0;
const int64_t kone = 1;
const int64_t ktwo = 2;
const int64_t kthree = 3;
const int64_t kfour = 4;
if (is_null_input_) {
return 0;
}
bound_list_[kzero] = ChooseBound(anchor_shape_[kone], positive_shape_[kone], dst_shape_[kone]);
bound_list_[kone] = ChooseBound(anchor_shape_[kone], negative_shape_[kone], dst_shape_[kone]);
bound_list_[ktwo] = ChooseBound(positive_shape_[kone], negative_shape_[kone], dst_shape_[kone]);
bound_ = dst_shape_[kone];
size_t outer_size = dst_shape_[kzero];
size_t inner_size = 1;
for (size_t i = 2; i < shape_size_; i++) {
inner_size *= dst_shape_[i];
}
T *anchor_ptr = nullptr;
T *positive_ptr = nullptr;
T *negative_ptr = nullptr;
M *margin_ptr = nullptr;
S *output_ptr = nullptr;
int flag = GetDeviceAddress<T>(input_ptrs, kzero, kernel_name_, &anchor_ptr);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<T>(input_ptrs, kone, kernel_name_, &positive_ptr);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<T>(input_ptrs, ktwo, kernel_name_, &negative_ptr);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<M>(input_ptrs, kthree, kernel_name_, &margin_ptr);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<S>(output_ptrs, kzero, kernel_name_, &output_ptr);
if (flag != 0) {
return flag;
}
std::vector<int64_t> input_shapes;
input_shapes.insert(input_shapes.end(), anchor_shape_.begin(), anchor_shape_.end());
input_shapes.insert(input_shapes.end(), positive_shape_.begin(), positive_shape_.end());
input_shapes.insert(input_shapes.end(), negative_shape_.begin(), negative_shape_.end());
input_shapes.insert(input_shapes.end(), dst_shape_.begin(), dst_shape_.end());
int64_t *anchor_shape_ptr = nullptr, *dst_shape_ptr = nullptr;
S *tem_output_ptr = nullptr;
size_t *bound_list_ptr = nullptr;
T *anchor_broadcast_ptr = anchor_ptr, *positive_broadcast_ptr = positive_ptr,
*negative_broadcast_ptr = positive_ptr;
flag = GetDeviceAddress<int64_t>(work_ptrs, kzero, kernel_name_, &anchor_shape_ptr);
if (flag != 0) {
return flag;
}
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(anchor_shape_ptr, &input_shapes[kzero], shape_size_ * sizeof(int64_t) * kfour,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(cuda_stream)),
"cudaMemcpyAsync workspace failed");
dst_shape_ptr = anchor_shape_ptr + kthree * shape_size_;
flag = GetDeviceAddress<S>(work_ptrs, kone, kernel_name_, &tem_output_ptr);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<size_t>(work_ptrs, ktwo, kernel_name_, &bound_list_ptr);
if (flag != 0) {
return flag;
}
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(bound_list_ptr, &bound_list_[kzero], sizeof(size_t) * kthree, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(cuda_stream)),
"cudaMemcpyAsync workspace failed");
if (need_broadcast_) {
flag = GetDeviceAddress<T>(work_ptrs, kthree, kernel_name_, &anchor_broadcast_ptr);
if (flag != 0) {
return flag;
}
positive_broadcast_ptr = anchor_broadcast_ptr + bound_ * outer_size * inner_size;
negative_broadcast_ptr = positive_broadcast_ptr + bound_ * outer_size * inner_size;
}
CalTripletMarginLoss(anchor_ptr, positive_ptr, negative_ptr, anchor_broadcast_ptr, positive_broadcast_ptr,
negative_broadcast_ptr, output_ptr, tem_output_ptr, anchor_shape_ptr, dst_shape_ptr,
outer_size, inner_size, bound_list_ptr, bound_, shape_size_, margin_ptr, attr_ptr_->p,
attr_ptr_->eps, reduction_, swap_, need_broadcast_, device_id_,
reinterpret_cast<cudaStream_t>(cuda_stream));
return 0;
}
void SetKernelParam(const GpuKernelAttrBasePtr &kernel_attr) override {
attr_ptr_ = std::dynamic_pointer_cast<TripletMarginLossAttr>(kernel_attr);
}
protected:
int CheckKernelParam() override {
std::string reduction_list = "[mean,none,sum]";
reduction_ = attr_ptr_->reduction;
if (reduction_ != "mean" && reduction_ != "none" && reduction_ != "sum") {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'reduciton' should be in " << reduction_list
<< "but got:" << reduction_;
return -1;
}
return 0;
}
size_t ChooseBound(size_t src_bound_first, size_t src_bound_second, size_t dst_bound) {
if (src_bound_first == 1 && src_bound_second == 1 && dst_bound != 1) {
return 1;
}
return dst_bound;
}
private:
std::shared_ptr<TripletMarginLossAttr> attr_ptr_;
std::vector<int64_t> anchor_shape_, positive_shape_, negative_shape_, dst_shape_;
size_t shape_size_;
size_t bound_list_[3];
size_t bound_;
bool need_broadcast_;
bool swap_;
bool is_null_input_;
std::string reduction_;
};
// half
template <typename M>
class TripletMarginLossHelperGpuKernel<half, M, half> : public GpuKernelHelperBase {
public:
explicit TripletMarginLossHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id)
: GpuKernelHelperBase(kernel_name, device_id) {
shape_size_ = 0;
need_broadcast_ = false;
is_null_input_ = false;
reduction_ = "mean";
bound_ = 0;
}
virtual ~TripletMarginLossHelperGpuKernel() = default;
int CalMemSize(const std::vector<std::vector<int64_t>> &input_shapes,
const std::vector<std::vector<int64_t>> &output_shapes) override {
constexpr size_t INPUT_NUM_1 = 3;
constexpr size_t OUTPUT_NUM = 1;
constexpr int64_t kzero = 0;
constexpr int64_t kone = 1;
constexpr int64_t ktwo = 2;
constexpr int64_t kthree = 3;
ResetResource();
dst_shape_.clear();
std::vector<std::vector<int64_t>> input_shapes_1;
input_shapes_1.emplace_back(input_shapes[kzero]);
input_shapes_1.emplace_back(input_shapes[kone]);
input_shapes_1.emplace_back(input_shapes[ktwo]);
int inp_flag =
CalShapesSizeInBytes<half>(input_shapes_1, INPUT_NUM_1, kernel_name_, "input_shapes_1", &input_size_list_);
if (inp_flag == -1) {
return inp_flag;
}
input_size_list_.emplace_back(ktwo * sizeof(M));
int out_flag =
CalShapesSizeInBytes<half>(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_);
if (out_flag == -1) {
return out_flag;
}
is_null_input_ = (inp_flag == 1 || out_flag == 1);
anchor_shape_ = input_shapes[kzero];
positive_shape_ = input_shapes[kone];
negative_shape_ = input_shapes[ktwo];
size_t dim_x = anchor_shape_.size();
size_t dim_positive = positive_shape_.size();
size_t dim_negative = negative_shape_.size();
shape_size_ = std::max(std::max(dim_x, dim_positive), dim_negative);
std::reverse(anchor_shape_.begin(), anchor_shape_.end());
std::reverse(positive_shape_.begin(), positive_shape_.end());
std::reverse(negative_shape_.begin(), negative_shape_.end());
if (dim_x < shape_size_) anchor_shape_.resize(shape_size_, kone);
if (dim_positive < shape_size_) positive_shape_.resize(shape_size_, kone);
if (dim_negative < shape_size_) negative_shape_.resize(shape_size_, kone);
std::reverse(anchor_shape_.begin(), anchor_shape_.end());
std::reverse(positive_shape_.begin(), positive_shape_.end());
std::reverse(negative_shape_.begin(), negative_shape_.end());
if (anchor_shape_ != positive_shape_ || anchor_shape_ != negative_shape_ || positive_shape_ != negative_shape_) {
need_broadcast_ = true;
}
int64_t tem_shape_size = 0;
for (size_t i = 0; i < shape_size_; i++) {
tem_shape_size++;
dst_shape_.push_back((int64_t)std::max(std::max(anchor_shape_[i], positive_shape_[i]), negative_shape_[i]));
}
std::vector<std::vector<int64_t>> workspace_shapes_sizet;
std::vector<std::vector<int64_t>> workspace_shapes_S;
constexpr size_t WORKSPACE_SIZET_NUM = 1;
constexpr size_t WORKSPACE_S_NUM = 1;
std::vector<int64_t> shape_shape;
tem_shape_size *= 4; // store 4 shapes
shape_shape.push_back(tem_shape_size);
workspace_shapes_sizet.emplace_back(shape_shape);
swap_ = attr_ptr_->swap;
std::vector<int64_t> tem_output_shape(dst_shape_);
tem_output_shape.erase(tem_output_shape.begin() + 1);
if (swap_) {
tem_output_shape.insert(tem_output_shape.begin(), kthree);
} else {
tem_output_shape.insert(tem_output_shape.begin(), ktwo);
}
workspace_shapes_S.emplace_back(tem_output_shape);
int work_flag = CalShapesSizeInBytes<int64_t>(workspace_shapes_sizet, WORKSPACE_SIZET_NUM, kernel_name_,
"workspace_shapes", &work_size_list_);
if (work_flag == -1) {
return work_flag;
}
work_flag = CalShapesSizeInBytes<float>(workspace_shapes_S, WORKSPACE_S_NUM, kernel_name_, "workspace_shapes",
&work_size_list_);
if (work_flag == -1) {
return work_flag;
}
size_t workspace_boundlist = kthree * sizeof(size_t);
work_size_list_.emplace_back(workspace_boundlist);
if (need_broadcast_) {
std::vector<std::vector<int64_t>> workspace_shapes_T;
constexpr size_t WORKSPACE_T_NUM = 1;
workspace_shapes_T.emplace_back(dst_shape_);
workspace_shapes_T[0].insert(workspace_shapes_T[0].begin(), kthree);
work_flag = CalShapesSizeInBytes<half>(workspace_shapes_T, WORKSPACE_T_NUM, kernel_name_, "workspace_shapes",
&work_size_list_);
if (work_flag == -1) {
return work_flag;
}
}
return CheckKernelParam();
}
int Process(const std::vector<void *> &input_ptrs, const std::vector<void *> &output_ptrs,
const std::vector<void *> &work_ptrs, void *cuda_stream) override {
const int64_t kzero = 0;
const int64_t kone = 1;
const int64_t ktwo = 2;
const int64_t kthree = 3;
const int64_t kfour = 4;
if (is_null_input_) {
return 0;
}
bound_list_[kzero] = ChooseBound(anchor_shape_[kone], positive_shape_[kone], dst_shape_[kone]);
bound_list_[kone] = ChooseBound(anchor_shape_[kone], negative_shape_[kone], dst_shape_[kone]);
bound_list_[ktwo] = ChooseBound(positive_shape_[kone], negative_shape_[kone], dst_shape_[kone]);
bound_ = dst_shape_[kone];
size_t outer_size = dst_shape_[kzero];
size_t inner_size = 1;
for (size_t i = 2; i < shape_size_; i++) {
inner_size *= dst_shape_[i];
}
half *anchor_ptr = nullptr;
half *positive_ptr = nullptr;
half *negative_ptr = nullptr;
M *margin_ptr = nullptr;
half *output_ptr = nullptr;
int flag = GetDeviceAddress<half>(input_ptrs, kzero, kernel_name_, &anchor_ptr);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<half>(input_ptrs, kone, kernel_name_, &positive_ptr);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<half>(input_ptrs, ktwo, kernel_name_, &negative_ptr);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<M>(input_ptrs, kthree, kernel_name_, &margin_ptr);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<half>(output_ptrs, kzero, kernel_name_, &output_ptr);
if (flag != 0) {
return flag;
}
std::vector<int64_t> input_shapes;
input_shapes.insert(input_shapes.end(), anchor_shape_.begin(), anchor_shape_.end());
input_shapes.insert(input_shapes.end(), positive_shape_.begin(), positive_shape_.end());
input_shapes.insert(input_shapes.end(), negative_shape_.begin(), negative_shape_.end());
input_shapes.insert(input_shapes.end(), dst_shape_.begin(), dst_shape_.end());
int64_t *anchor_shape_ptr = nullptr, *dst_shape_ptr = nullptr;
float *tem_output_ptr = nullptr;
size_t *bound_list_ptr = nullptr;
half *anchor_broadcast_ptr = anchor_ptr, *positive_broadcast_ptr = positive_ptr,
*negative_broadcast_ptr = positive_ptr;
flag = GetDeviceAddress<int64_t>(work_ptrs, kzero, kernel_name_, &anchor_shape_ptr);
if (flag != 0) {
return flag;
}
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(anchor_shape_ptr, &input_shapes[kzero], shape_size_ * sizeof(int64_t) * kfour,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(cuda_stream)),
"cudaMemcpyAsync workspace failed");
dst_shape_ptr = anchor_shape_ptr + kthree * shape_size_;
flag = GetDeviceAddress<float>(work_ptrs, kone, kernel_name_, &tem_output_ptr);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<size_t>(work_ptrs, ktwo, kernel_name_, &bound_list_ptr);
if (flag != 0) {
return flag;
}
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(bound_list_ptr, &bound_list_[kzero], sizeof(size_t) * kthree, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(cuda_stream)),
"cudaMemcpyAsync workspace failed");
if (need_broadcast_) {
flag = GetDeviceAddress<half>(work_ptrs, kthree, kernel_name_, &anchor_broadcast_ptr);
if (flag != 0) {
return flag;
}
positive_broadcast_ptr = anchor_broadcast_ptr + bound_ * outer_size * inner_size;
negative_broadcast_ptr = positive_broadcast_ptr + bound_ * outer_size * inner_size;
}
CalTripletMarginLoss(anchor_ptr, positive_ptr, negative_ptr, anchor_broadcast_ptr, positive_broadcast_ptr,
negative_broadcast_ptr, output_ptr, tem_output_ptr, anchor_shape_ptr, dst_shape_ptr,
outer_size, inner_size, bound_list_ptr, bound_, shape_size_, margin_ptr, attr_ptr_->p,
attr_ptr_->eps, reduction_, swap_, need_broadcast_, device_id_,
reinterpret_cast<cudaStream_t>(cuda_stream));
return 0;
}
void SetKernelParam(const GpuKernelAttrBasePtr &kernel_attr) override {
attr_ptr_ = std::dynamic_pointer_cast<TripletMarginLossAttr>(kernel_attr);
}
protected:
int CheckKernelParam() override {
std::string reduction_list = "[mean,none,sum]";
reduction_ = attr_ptr_->reduction;
if (reduction_ != "mean" && reduction_ != "none" && reduction_ != "sum") {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'reduciton' should be in " << reduction_list
<< "but got:" << reduction_;
return -1;
}
return 0;
}
size_t ChooseBound(size_t src_bound_first, size_t src_bound_second, size_t dst_bound) {
if (src_bound_first == 1 && src_bound_second == 1 && dst_bound != 1) {
return 1;
}
return dst_bound;
}
private:
std::shared_ptr<TripletMarginLossAttr> attr_ptr_;
std::vector<int64_t> anchor_shape_, positive_shape_, negative_shape_, dst_shape_;
size_t shape_size_;
size_t bound_list_[3];
size_t bound_;
bool need_broadcast_;
bool swap_;
bool is_null_input_;
std::string reduction_;
};
} // namespace cukernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_TRIPLET_MARGIN_LOSS_HELPER_H_

View File

@ -0,0 +1,508 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "triplet_margin_loss_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
__device__ __forceinline__ int64_t Index(const int64_t &index, const int64_t &dim) { return dim == 1 ? 0 : index; }
template <typename T>
__global__ void FillAndBroadcast(const int64_t size, const size_t shape_size,
const int64_t *tensor_shapes, const int64_t *dst_shape,
const T *anchor, const T *positive, const T *negative,
T *anchor_broadcast) {
const T *pair_tensor[3] = {anchor, positive, negative};
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < 3*size; pos += blockDim.x * gridDim.x) {
const size_t mode = pos/size;
const int64_t *src_shape = tensor_shapes + shape_size * mode;
size_t tmp_pos = pos % size;
size_t pos_size = size / dst_shape[0];
size_t dst_index_array[8];
dst_index_array[0] = tmp_pos / pos_size;
for (size_t i = 1; i < shape_size; i++) {
tmp_pos -= dst_index_array[i - 1] * pos_size;
pos_size = pos_size / dst_shape[i];
dst_index_array[i] = tmp_pos / pos_size;
}
size_t src_size = 1;
for (size_t i = 0; i < shape_size; i++) {
src_size *= src_shape[i];
}
size_t src_pos = 0;
for (size_t i = 0; i < shape_size; i++) {
src_size /= src_shape[i];
size_t length_by_index = Index(dst_index_array[i], src_shape[i]) * src_size;
src_pos += length_by_index;
}
(anchor_broadcast + mode * size)[pos % size] = pair_tensor[mode][src_pos];
}
return;
}
template <typename T, typename S>
__global__ void PairwiseDistance(const T *anchor, const T *positive, const T *negative,
const size_t *bound_list, const size_t bound, const size_t outer_size,
const size_t inner_size, S *tem_output, const size_t n, const int64_t p,
const float eps) {
const T *pair_tensor[3] = {anchor, positive, negative};
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
pos < n * outer_size * inner_size; pos += gridDim.x * blockDim.x) {
size_t mode = pos / (outer_size * inner_size);
size_t idx = pos % (outer_size * inner_size);
S res = 0;
size_t x = idx / inner_size % outer_size;
size_t y = idx % inner_size;
for (int i = 0; i < bound_list[mode]; i++) {
size_t input_offset = x * bound * inner_size + i * inner_size + y;
S base =
abs(static_cast<T>(pair_tensor[mode / 2][input_offset] - pair_tensor[(mode + 3) / 2][input_offset]) + eps);
S tem = pow(base, static_cast<S>(p));
res += tem;
}
tem_output[pos] = pow(res, static_cast<S>(1.0 / p));
}
return;
}
template <typename S>
__global__ void PairwiseDistancePzero(const size_t *bound_list, const size_t output_size,
S *tem_output, const size_t n) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < n * output_size; pos += gridDim.x * blockDim.x) {
size_t mode = pos / output_size;
tem_output[pos] = static_cast<S>(bound_list[mode]);
}
return;
}
template <typename S>
__global__ void SwapTrue(S *tem_output, const size_t output_size) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_size; pos += gridDim.x * blockDim.x) {
tem_output[pos + output_size] = tem_output[pos + output_size] > tem_output[pos + 2 * output_size] ?
tem_output[pos + 2 * output_size] : tem_output[pos + output_size];
}
return;
}
template <typename M, typename S>
__global__ void MaxReduction(S *tem_output, S *output, const size_t output_size, const M *margin) {
S lower_bound = 0;
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_size; pos += gridDim.x * blockDim.x) {
output[pos] = max(static_cast<float>(margin[0]) + tem_output[pos] - tem_output[pos + output_size], lower_bound);
}
return;
}
template <typename S>
__global__ void AddTile(S *tmp_loss, size_t index) {
tmp_loss[0] += tmp_loss[index];
}
template <typename S>
__global__ void PartialSum(S *tmp_loss, size_t stride) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < stride; i += blockDim.x * gridDim.x) {
tmp_loss[i] += tmp_loss[i + stride];
}
}
template <typename S>
__global__ void ReductionDivde(S *output, S *tem_output, const size_t k) {
output[0] = tem_output[0] / k;
}
// half
template <>
__global__ void PairwiseDistance(const half *anchor, const half *positive, const half *negative,
const size_t *bound_list, const size_t bound, const size_t outer_size,
const size_t inner_size, float *tem_output, const size_t n, const int64_t p,
const float eps) {
const half *pair_tensor[3] = {anchor, positive, negative};
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
pos < n * outer_size * inner_size; pos += gridDim.x * blockDim.x) {
size_t mode = pos / (outer_size * inner_size);
size_t idx = pos % (outer_size * inner_size);
float res = 0;
size_t x = idx / inner_size % outer_size;
size_t y = idx % inner_size;
for (int i = 0; i < bound_list[mode]; i++) {
size_t input_offset = x * bound * inner_size + i * inner_size + y;
float base = abs(__half2float(pair_tensor[mode / 2][input_offset]) -
__half2float(pair_tensor[(mode+3) / 2][input_offset]) + eps);
float tem = pow(base, p);
res += tem;
}
tem_output[pos] = pow(res, 1.0 / p);
}
return;
}
// half
template <typename M>
__global__ void MaxReduction(float *tem_output, half *output, const size_t output_size, const M *margin) {
float lower_bound = 0;
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_size; pos += gridDim.x * blockDim.x) {
output[pos] = __float2half(max(margin[0] + tem_output[pos] - tem_output[pos + output_size], lower_bound));
}
return;
}
// half
__global__ void ReductionDivde(half *output, float *tem_output, const size_t k) {
output[0] = __float2half((tem_output[0] / k));
}
// Complex
template <typename S>
__global__ void PairwiseDistance(const Complex<S> *anchor, const Complex<S> *positive,
const Complex<S> *negative, const size_t *bound_list,
const size_t bound, const size_t outer_size, const size_t inner_size,
S *tem_output, const size_t n, const int64_t p, const float eps) {
const Complex<S> *pair_tensor[3] = {anchor, positive, negative};
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
pos < n * outer_size * inner_size; pos += gridDim.x * blockDim.x) {
size_t mode = pos / (outer_size * inner_size);
size_t idx = pos % (outer_size * inner_size);
S res = 0;
size_t x = idx / inner_size % outer_size;
size_t y = idx % inner_size;
for (int i = 0; i < bound_list[mode]; i++) {
size_t input_offset = x * bound * inner_size + i * inner_size + y;
Complex<S> base_complex =
pair_tensor[mode / 2][input_offset] - pair_tensor[(mode + 3) / 2][input_offset] + static_cast<S>(eps);
S base = sqrt((base_complex.real() * base_complex.real() + base_complex.imag() * base_complex.imag()));
S tem = pow(base, static_cast<S>(p));
res += tem;
}
tem_output[pos] = pow(res, 1.0 / p);
}
return;
}
template <typename T, typename M, typename S, typename H>
void CalTripletMarginLoss(const T *anchor, const T *positive, const T *negative, T *anchor_broadcast,
T *positive_broadcast, T *negative_broadcast, S *output, H *tem_output,
const int64_t *tensor_shapes, const int64_t *dst_shape, const size_t outer_size,
const size_t inner_size, const size_t *bound_list, const size_t bound,
const size_t shape_size, M *margin, const int64_t p, const float eps,
const std::string reduction, const bool swap, const bool need_broadcast,
const uint32_t &device_id, cudaStream_t cuda_stream) {
const int64_t size = outer_size * inner_size * bound;
size_t n;
if (swap)
n = 3;
else
n = 2;
const size_t output_size = outer_size * inner_size;
if (p == 0) {
PairwiseDistancePzero<<<CUDA_BLOCKS(device_id, n * output_size), CUDA_THREADS(device_id), 0, cuda_stream>>>
(bound_list, output_size, tem_output, n);
} else if (need_broadcast) {
FillAndBroadcast<<<CUDA_BLOCKS(device_id, 3 * size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
size, shape_size, tensor_shapes, dst_shape, anchor, positive, negative,
anchor_broadcast);
PairwiseDistance<<<CUDA_BLOCKS(device_id, n * output_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
anchor_broadcast, positive_broadcast, negative_broadcast, bound_list, bound, outer_size,
inner_size, tem_output, n, p, eps);
} else {
PairwiseDistance<<<CUDA_BLOCKS(device_id, n * output_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
anchor, positive, negative, bound_list, bound, outer_size, inner_size, tem_output, n, p, eps);
}
if (swap) {
SwapTrue<<<CUDA_BLOCKS(device_id, output_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(tem_output,
output_size);
}
if (reduction == "none") {
MaxReduction<<<CUDA_BLOCKS(device_id, output_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(tem_output,
output, output_size, margin);
} else {
MaxReduction<<<CUDA_BLOCKS(device_id, output_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(tem_output,
tem_output, output_size, margin);
if (output_size % 2 == 1 && output_size != 1) {
AddTile<<<1, 1, 0, cuda_stream>>>(tem_output, output_size - 1);
}
for (size_t stride = output_size / 2; stride > 0; stride >>= 1) {
PartialSum<<<CUDA_BLOCKS(device_id, stride), CUDA_THREADS(device_id), 0, cuda_stream>>>(tem_output, stride);
if (stride > 2 && stride % 2 == 1) {
AddTile<<<1, 1, 0, cuda_stream>>>(tem_output, stride - 1);
}
}
if (reduction == "mean") {
ReductionDivde<<<1, 1, 0, cuda_stream>>>(output, tem_output, output_size);
} else {
ReductionDivde<<<1, 1, 0, cuda_stream>>>(output, tem_output, 1);
}
}
return;
}
template CUDA_LIB_EXPORT void CalTripletMarginLoss<int8_t, float, float, float>(
const int8_t *anchor, const int8_t *positive, const int8_t *negative,
int8_t *anchor_broadcast, int8_t *positive_broadcast, int8_t *negative_broadcast,
float *output, float *tem_output, const int64_t *tensor_shapes,
const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin,
const int64_t p, const float eps, const std::string reduction,
const bool swap, const bool need_broadcast, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalTripletMarginLoss<int16_t, float, float, float>(
const int16_t *anchor, const int16_t *positive, const int16_t *negative,
int16_t *anchor_broadcast, int16_t *positive_broadcast, int16_t *negative_broadcast,
float *output, float *tem_output, const int64_t *tensor_shapes,
const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin,
const int64_t p, const float eps, const std::string reduction,
const bool swap, const bool need_broadcast, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalTripletMarginLoss<int32_t, float, float, float>(
const int32_t *anchor, const int32_t *positive, const int32_t *negative,
int32_t *anchor_broadcast, int32_t *positive_broadcast, int32_t *negative_broadcast,
float *output, float *tem_output, const int64_t *tensor_shapes,
const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin,
const int64_t p, const float eps, const std::string reduction,
const bool swap, const bool need_broadcast, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalTripletMarginLoss<int64_t, float, float, float>(
const int64_t *anchor, const int64_t *positive, const int64_t *negative,
int64_t *anchor_broadcast, int64_t *positive_broadcast, int64_t *negative_broadcast,
float *output, float *tem_output, const int64_t *tensor_shapes,
const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin,
const int64_t p, const float eps, const std::string reduction,
const bool swap, const bool need_broadcast, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalTripletMarginLoss<uint8_t, float, float, float>(
const uint8_t *anchor, const uint8_t *positive, const uint8_t *negative,
uint8_t *anchor_broadcast, uint8_t *positive_broadcast, uint8_t *negative_broadcast,
float *output, float *tem_output, const int64_t *tensor_shapes,
const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin,
const int64_t p, const float eps, const std::string reduction,
const bool swap, const bool need_broadcast, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalTripletMarginLoss<uint16_t, float, float, float>(
const uint16_t *anchor, const uint16_t *positive, const uint16_t *negative,
uint16_t *anchor_broadcast, uint16_t *positive_broadcast, uint16_t *negative_broadcast,
float *output, float *tem_output, const int64_t *tensor_shapes,
const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin,
const int64_t p, const float eps, const std::string reduction,
const bool swap, const bool need_broadcast, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalTripletMarginLoss<uint32_t, float, float, float>(
const uint32_t *anchor, const uint32_t *positive, const uint32_t *negative,
uint32_t *anchor_broadcast, uint32_t *positive_broadcast, uint32_t *negative_broadcast,
float *output, float *tem_output, const int64_t *tensor_shapes,
const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin,
const int64_t p, const float eps, const std::string reduction,
const bool swap, const bool need_broadcast, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalTripletMarginLoss<uint64_t, float, float, float>(
const uint64_t *anchor, const uint64_t *positive, const uint64_t *negative,
uint64_t *anchor_broadcast, uint64_t *positive_broadcast, uint64_t *negative_broadcast,
float *output, float *tem_output, const int64_t *tensor_shapes,
const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin,
const int64_t p, const float eps, const std::string reduction,
const bool swap, const bool need_broadcast, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalTripletMarginLoss<double, float, double, double>(
const double *anchor, const double *positive, const double *negative,
double *anchor_broadcast, double *positive_broadcast, double *negative_broadcast,
double *output, double *tem_output, const int64_t *tensor_shapes,
const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin,
const int64_t p, const float eps, const std::string reduction,
const bool swap, const bool need_broadcast, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalTripletMarginLoss<float, float, float, float>(
const float *anchor, const float *positive, const float *negative,
float *anchor_broadcast, float *positive_broadcast, float *negative_broadcast,
float *output, float *tem_output, const int64_t *tensor_shapes,
const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin,
const int64_t p, const float eps, const std::string reduction,
const bool swap, const bool need_broadcast, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalTripletMarginLoss<half, float, half, float>(
const half *anchor, const half *positive, const half *negative,
half *anchor_broadcast, half *positive_broadcast, half *negative_broadcast,
half *output, float *tem_output, const int64_t *tensor_shapes,
const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
const size_t *bound_list, const size_t bound, const size_t shape_size, float *margin,
const int64_t p, const float eps, const std::string reduction,
const bool swap, const bool need_broadcast, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void
CalTripletMarginLoss<Complex<float>, float, float, float>(
const Complex<float> *anchor, const Complex<float> *positive, const Complex<float> *negative,
Complex<float> *anchor_broadcast, Complex<float> *positive_broadcast, Complex<float> *negative_broadcast,
float *output, float *tem_output,
const int64_t *tensor_shapes, const int64_t *dst_shape,
const size_t outer_size, const size_t inner_size, const size_t *bound_list, const size_t bound,
const size_t shape_size, float *margin, const int64_t p,
const float eps, const std::string reduction, const bool swap,
const bool need_broadcast, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void
CalTripletMarginLoss<Complex<double>, float, double, double>(
const Complex<double> *anchor, const Complex<double> *positive,
const Complex<double> *negative,
Complex<double> *anchor_broadcast, Complex<double> *positive_broadcast, Complex<double> *negative_broadcast,
double *output, double *tem_output,
const int64_t *tensor_shapes, const int64_t *dst_shape,
const size_t outer_size, const size_t inner_size, const size_t *bound_list, const size_t bound,
const size_t shape_size, float *margin, const int64_t p,
const float eps, const std::string reduction, const bool swap,
const bool need_broadcast, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalTripletMarginLoss<int8_t, double, float, float>(
const int8_t *anchor, const int8_t *positive, const int8_t *negative,
int8_t *anchor_broadcast, int8_t *positive_broadcast, int8_t *negative_broadcast,
float *output, float *tem_output, const int64_t *tensor_shapes,
const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
const size_t *bound_list, const size_t bound, const size_t shape_size, double *margin,
const int64_t p, const float eps, const std::string reduction,
const bool swap, const bool need_broadcast, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalTripletMarginLoss<int16_t, double, float, float>(
const int16_t *anchor, const int16_t *positive, const int16_t *negative,
int16_t *anchor_broadcast, int16_t *positive_broadcast, int16_t *negative_broadcast,
float *output, float *tem_output, const int64_t *tensor_shapes,
const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
const size_t *bound_list, const size_t bound, const size_t shape_size, double *margin,
const int64_t p, const float eps, const std::string reduction,
const bool swap, const bool need_broadcast, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalTripletMarginLoss<int32_t, double, float, float>(
const int32_t *anchor, const int32_t *positive, const int32_t *negative,
int32_t *anchor_broadcast, int32_t *positive_broadcast, int32_t *negative_broadcast,
float *output, float *tem_output, const int64_t *tensor_shapes,
const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
const size_t *bound_list, const size_t bound, const size_t shape_size, double *margin,
const int64_t p, const float eps, const std::string reduction,
const bool swap, const bool need_broadcast, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalTripletMarginLoss<int64_t, double, float, float>(
const int64_t *anchor, const int64_t *positive, const int64_t *negative,
int64_t *anchor_broadcast, int64_t *positive_broadcast, int64_t *negative_broadcast,
float *output, float *tem_output, const int64_t *tensor_shapes,
const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
const size_t *bound_list, const size_t bound, const size_t shape_size, double *margin,
const int64_t p, const float eps, const std::string reduction,
const bool swap, const bool need_broadcast, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalTripletMarginLoss<uint8_t, double, float, float>(
const uint8_t *anchor, const uint8_t *positive, const uint8_t *negative,
uint8_t *anchor_broadcast, uint8_t *positive_broadcast, uint8_t *negative_broadcast,
float *output, float *tem_output, const int64_t *tensor_shapes,
const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
const size_t *bound_list, const size_t bound, const size_t shape_size, double *margin,
const int64_t p, const float eps, const std::string reduction,
const bool swap, const bool need_broadcast, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalTripletMarginLoss<uint16_t, double, float, float>(
const uint16_t *anchor, const uint16_t *positive, const uint16_t *negative,
uint16_t *anchor_broadcast, uint16_t *positive_broadcast, uint16_t *negative_broadcast,
float *output, float *tem_output, const int64_t *tensor_shapes,
const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
const size_t *bound_list, const size_t bound, const size_t shape_size, double *margin,
const int64_t p, const float eps, const std::string reduction,
const bool swap, const bool need_broadcast, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalTripletMarginLoss<uint32_t, double, float, float>(
const uint32_t *anchor, const uint32_t *positive, const uint32_t *negative,
uint32_t *anchor_broadcast, uint32_t *positive_broadcast, uint32_t *negative_broadcast,
float *output, float *tem_output, const int64_t *tensor_shapes,
const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
const size_t *bound_list, const size_t bound, const size_t shape_size, double *margin,
const int64_t p, const float eps, const std::string reduction,
const bool swap, const bool need_broadcast, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalTripletMarginLoss<uint64_t, double, float, float>(
const uint64_t *anchor, const uint64_t *positive, const uint64_t *negative,
uint64_t *anchor_broadcast, uint64_t *positive_broadcast, uint64_t *negative_broadcast,
float *output, float *tem_output, const int64_t *tensor_shapes,
const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
const size_t *bound_list, const size_t bound, const size_t shape_size, double *margin,
const int64_t p, const float eps, const std::string reduction,
const bool swap, const bool need_broadcast, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalTripletMarginLoss<double, double, double, double>(
const double *anchor, const double *positive, const double *negative,
double *anchor_broadcast, double *positive_broadcast, double *negative_broadcast,
double *output, double *tem_output, const int64_t *tensor_shapes,
const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
const size_t *bound_list, const size_t bound, const size_t shape_size, double *margin,
const int64_t p, const float eps, const std::string reduction,
const bool swap, const bool need_broadcast, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalTripletMarginLoss<float, double, float, float>(
const float *anchor, const float *positive, const float *negative,
float *anchor_broadcast, float *positive_broadcast, float *negative_broadcast,
float *output, float *tem_output, const int64_t *tensor_shapes,
const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
const size_t *bound_list, const size_t bound, const size_t shape_size, double *margin,
const int64_t p, const float eps, const std::string reduction,
const bool swap, const bool need_broadcast, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalTripletMarginLoss<half, double, half, float>(
const half *anchor, const half *positive, const half *negative,
half *anchor_broadcast, half *positive_broadcast, half *negative_broadcast,
half *output, float *tem_output, const int64_t *tensor_shapes,
const int64_t *dst_shape, const size_t outer_size, const size_t inner_size,
const size_t *bound_list, const size_t bound, const size_t shape_size, double *margin,
const int64_t p, const float eps, const std::string reduction,
const bool swap, const bool need_broadcast, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void
CalTripletMarginLoss<Complex<float>, double, float, float>(
const Complex<float> *anchor, const Complex<float> *positive,
const Complex<float> *negative, Complex<float> *anchor_broadcast, Complex<float> *positive_broadcast,
Complex<float> *negative_broadcast, float *output, float *tem_output,
const int64_t *tensor_shapes, const int64_t *dst_shape,
const size_t outer_size, const size_t inner_size, const size_t *bound_list, const size_t bound,
const size_t shape_size, double *margin, const int64_t p,
const float eps, const std::string reduction, const bool swap,
const bool need_broadcast, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void
CalTripletMarginLoss<Complex<double>, double, double, double>(
const Complex<double> *anchor, const Complex<double> *positive,
const Complex<double> *negative, Complex<double> *anchor_broadcast, Complex<double> *positive_broadcast,
Complex<double> *negative_broadcast, double *output, double *tem_output,
const int64_t *tensor_shapes, const int64_t *dst_shape,
const size_t outer_size, const size_t inner_size, const size_t *bound_list, const size_t bound,
const size_t shape_size, double *margin, const int64_t p,
const float eps, const std::string reduction, const bool swap,
const bool need_broadcast, const uint32_t &device_id,
cudaStream_t cuda_stream);

View File

@ -0,0 +1,37 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRIPLET_MARGIN_LOSS_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRIPLET_MARGIN_LOSS_IMPL_CUH_
#include <string>
#include <algorithm>
#include "include/cuda_fp16.h"
#include "mindapi/base/types.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
template <typename T> using Complex = mindspore::utils::Complex<T>;
template <typename T, typename M, typename S, typename H>
CUDA_LIB_EXPORT void CalTripletMarginLoss(const T *anchor, const T *positive, const T *negative, T *anchor_broadcast,
T *positive_broadcast, T *negative_broadcast, S *output, H *tem_output, const int64_t *tensor_shapes,
const int64_t *dst_shape, const size_t outer_size, const size_t inner_size, const size_t *bound_list,
const size_t bound, const size_t shape_size, M *margin, const int64_t p, const float eps,
const std::string reduction, const bool swap, const bool need_broadcast, const uint32_t &device_id,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_TRIPLET_MARGIN_LOSS_IMPL_CUH_

View File

@ -0,0 +1,294 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/nn/triplet_margin_loss_gpu_kernel.h"
#include <utility>
namespace mindspore {
namespace kernel {
namespace {
template <typename T, typename M, typename S>
std::unique_ptr<cukernel::GpuKernelHelperBase> CreateTripletMarginLossKernelPtr(const std::string &kernel_name,
const uint32_t &device_id) {
return std::make_unique<cukernel::TripletMarginLossHelperGpuKernel<T, M, S>>(kernel_name, device_id);
}
using TripletMarginLossPtrCreatorFunc =
std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;
const std::vector<std::pair<KernelAttr, TripletMarginLossPtrCreatorFunc>> kernel_attr = {
{KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
CreateTripletMarginLossKernelPtr<Complex<float>, float, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat64),
CreateTripletMarginLossKernelPtr<Complex<double>, float, double>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat64),
CreateTripletMarginLossKernelPtr<double, float, double>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
CreateTripletMarginLossKernelPtr<float, float, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16),
CreateTripletMarginLossKernelPtr<half, float, half>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
CreateTripletMarginLossKernelPtr<int16_t, float, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
CreateTripletMarginLossKernelPtr<int32_t, float, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
CreateTripletMarginLossKernelPtr<int64_t, float, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
CreateTripletMarginLossKernelPtr<int8_t, float, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
CreateTripletMarginLossKernelPtr<uint16_t, float, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
CreateTripletMarginLossKernelPtr<uint32_t, float, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
CreateTripletMarginLossKernelPtr<uint64_t, float, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
CreateTripletMarginLossKernelPtr<uint8_t, float, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat32),
CreateTripletMarginLossKernelPtr<Complex<float>, double, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
CreateTripletMarginLossKernelPtr<Complex<double>, double, double>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
CreateTripletMarginLossKernelPtr<double, double, double>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat32),
CreateTripletMarginLossKernelPtr<float, double, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat16),
CreateTripletMarginLossKernelPtr<half, double, half>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat32),
CreateTripletMarginLossKernelPtr<int16_t, double, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat32),
CreateTripletMarginLossKernelPtr<int32_t, double, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat32),
CreateTripletMarginLossKernelPtr<int64_t, double, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat32),
CreateTripletMarginLossKernelPtr<int8_t, double, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat32),
CreateTripletMarginLossKernelPtr<uint16_t, double, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat32),
CreateTripletMarginLossKernelPtr<uint32_t, double, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat32),
CreateTripletMarginLossKernelPtr<uint64_t, double, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat32),
CreateTripletMarginLossKernelPtr<uint8_t, double, float>}};
} // namespace
bool TripletMarginLossGpuKernelMod::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
std::vector<void *> input_ptrs = ConvertPtrs(inputs);
std::vector<void *> work_ptrs = ConvertPtrs(workspace);
std::vector<void *> output_ptrs = ConvertPtrs(outputs);
if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) {
return false;
}
return true;
}
bool TripletMarginLossGpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::TripletMarginLoss>(base_operator);
kernel_name_ = kernel_ptr->name();
auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport());
if (!is_match) {
return false;
}
attr_ptr_->p = kernel_ptr->get_p();
attr_ptr_->swap = kernel_ptr->get_swap();
attr_ptr_->eps = kernel_ptr->get_eps();
attr_ptr_->reduction = kernel_ptr->get_reduction();
helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_));
helper_ptr_->SetKernelParam(attr_ptr_);
Resize(base_operator, inputs, outputs);
return true;
}
int TripletMarginLossGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
for (const auto &input : inputs) {
auto input_shape = input->GetShapeVector();
if (!IsValidShape(input_shape)) {
return KRET_UNKNOWN_SHAPE;
}
}
constexpr int64_t kzero = 0;
constexpr int64_t kone = 1;
constexpr int64_t ktwo = 2;
constexpr int64_t kthree = 3;
std::vector<std::vector<int64_t>> input_shapes;
std::vector<std::vector<int64_t>> output_shapes;
std::vector<int64_t> inp_shape1 = inputs[kzero]->GetShapeVector();
std::vector<int64_t> inp_shape2 = inputs[kone]->GetShapeVector();
std::vector<int64_t> inp_shape3 = inputs[ktwo]->GetShapeVector();
std::vector<int64_t> inp_shape4 = inputs[kthree]->GetShapeVector();
std::vector<int64_t> out_shape = outputs[kzero]->GetShapeVector();
input_shapes.emplace_back(inp_shape1);
input_shapes.emplace_back(inp_shape2);
input_shapes.emplace_back(inp_shape3);
input_shapes.emplace_back(inp_shape4);
output_shapes.emplace_back(out_shape);
if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) {
return KRET_RESIZE_FAILED;
}
input_size_list_ = helper_ptr_->GetInputSizeList();
output_size_list_ = helper_ptr_->GetOutputSizeList();
workspace_size_list_ = helper_ptr_->GetWorkSizeList();
return KRET_OK;
}
std::vector<KernelAttr> TripletMarginLossGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, TripletMarginLossPtrCreatorFunc> &item) { return item.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, TripletMarginLoss, TripletMarginLossGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,61 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_TRIPLET_MARGIN_LOSS_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_TRIPLET_MARGIN_LOSS_GPU_KERNEL_H_
#include <vector>
#include <string>
#include <memory>
#include <algorithm>
#include <functional>
#include <map>
#include "mindspore/core/ops/triplet_margin_loss.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/triplet_margin_loss_helper.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
#include "kernel/kernel.h"
namespace mindspore {
namespace kernel {
template <typename T>
using Complex = mindspore::utils::Complex<T>;
class TripletMarginLossGpuKernelMod : public NativeGpuKernelMod {
public:
TripletMarginLossGpuKernelMod() { attr_ptr_ = std::make_shared<cukernel::TripletMarginLossAttr>(); }
~TripletMarginLossGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
std::vector<KernelAttr> GetOpSupport() override;
private:
std::unique_ptr<cukernel::GpuKernelHelperBase> helper_ptr_{nullptr};
std::shared_ptr<cukernel::TripletMarginLossAttr> attr_ptr_{nullptr};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_TRIPLET_MARGIN_LOSS_GPU_KERNEL_H_

View File

@ -270,6 +270,7 @@ constexpr auto kSideEffectIO = "side_effect_io";
constexpr auto kDeviceType = "device_type";
constexpr auto kExclusive = "exclusive";
constexpr auto kReverse = "reverse";
constexpr auto kSwap = "swap";
constexpr auto kSplitStride = "split_stride";
constexpr auto kExtendTop = "extend_top";
constexpr auto kExtendBottom = "extend_bottom";

View File

@ -77,7 +77,7 @@ TypePtr TripletMarginLossInferType(const PrimitivePtr &primitive, const std::vec
auto op_name = primitive->name();
const std::set<TypePtr> valid_types = {kComplex64, kComplex128, kFloat64, kFloat32, kFloat16, kInt16, kInt32,
kInt64, kInt8, kUInt16, kUInt32, kUInt64, kUInt8};
const std::set<TypePtr> valid_types2 = {kFloat32};
const std::set<TypePtr> valid_types2 = {kFloat32, kFloat64};
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[kInputIndex0]->BuildType());
types.emplace("positive", input_args[kInputIndex1]->BuildType());
@ -87,8 +87,14 @@ TypePtr TripletMarginLossInferType(const PrimitivePtr &primitive, const std::vec
(void)CheckAndConvertUtils::CheckTensorTypeValid("margin", margin, valid_types2, op_name);
auto x_type = input_args[kInputIndex0]->BuildType();
TypePtr output;
if (x_type == kFloat16) {
if (x_type->isa<TensorType>()) {
auto tensor_type = x_type->cast<TensorTypePtr>();
x_type = tensor_type->element();
}
if (IsIdentidityOrSubclass(x_type, kFloat16)) {
output = kFloat16;
} else if (IsIdentidityOrSubclass(x_type, kFloat64) || IsIdentidityOrSubclass(x_type, kComplex128)) {
output = kFloat64;
} else {
output = kFloat32;
}
@ -97,6 +103,16 @@ TypePtr TripletMarginLossInferType(const PrimitivePtr &primitive, const std::vec
} // namespace
MIND_API_OPERATOR_IMPL(TripletMarginLoss, BaseOperator);
void TripletMarginLoss::set_p(const int64_t p) { (void)this->AddAttr(kP, api::MakeValue(p)); }
void TripletMarginLoss::set_eps(const float eps) { (void)this->AddAttr(kEps, api::MakeValue(eps)); }
void TripletMarginLoss::set_swap(const bool swap) { (void)this->AddAttr(kSwap, api::MakeValue(swap)); }
void TripletMarginLoss::set_reduction(const std::string &reduction) {
(void)this->AddAttr(kReduction, api::MakeValue(reduction));
}
int64_t TripletMarginLoss::get_p() const { return GetValue<int64_t>(GetAttr(kP)); }
float TripletMarginLoss::get_eps() const { return GetValue<float>(GetAttr(kEps)); }
bool TripletMarginLoss::get_swap() const { return GetValue<bool>(GetAttr(kSwap)); }
std::string TripletMarginLoss::get_reduction() const { return GetValue<std::string>(GetAttr(kReduction)); }
AbstractBasePtr TripletMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);

View File

@ -34,6 +34,16 @@ class MIND_API TripletMarginLoss : public BaseOperator {
TripletMarginLoss() : BaseOperator(kNameTripletMarginLoss) {
InitIOName({"x", "positive", "negative", "margin"}, {"y"});
}
void Init(const int64_t p = 2, const float eps = 1e-6, const bool swap = false,
const std::string &reduction = "mean");
void set_p(const int64_t p);
void set_eps(const float eps);
void set_swap(const bool swap);
void set_reduction(const std::string &reduction);
int64_t get_p() const;
float get_eps() const;
bool get_swap() const;
std::string get_reduction() const;
};
abstract::AbstractBasePtr TripletMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -10069,7 +10069,7 @@ class TripletMarginLoss(Primitive):
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``CPU`` ``GPU``
Examples:
>>> loss = ops.TripletMarginLoss()

View File

@ -0,0 +1,112 @@
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.ops.operations.nn_ops as ops
import torch
class NetTripletMarginLoss(nn.Cell):
def __init__(self, p=2, swap=False, eps=1e-6, reduction="mean"):
super(NetTripletMarginLoss, self).__init__()
self.triplet_margin_loss = ops.TripletMarginLoss(
p=p, swap=swap, eps=eps, reduction=reduction)
def construct(self, anchor, positive, negative, margin):
return self.triplet_margin_loss(anchor, positive, negative, margin)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_triplet_margin_loss_float64():
"""
Feature: Input type of float64
Description: Input type of [float64, float64, float64, float64].
Expectation: success.
"""
for mode in [context.PYNATIVE_MODE, context.GRAPH_MODE]:
context.set_context(mode=mode, device_target="GPU")
data_type = np.float64
anchor_array = np.array([[1.3, 20.5, 5.6]]).astype(data_type)
positive_array = np.array([[2., 10., 1.],
[6., 7., 10.],
[13., 4., 1.],
[0.33, -4, -1.5]]).astype(data_type)
negative_array = np.array([[2., 21., 6.],
[68., 9., 10.],
[131., 25., 16.],
[0.31, -0.14, -16.]]).astype(data_type)
margin = np.float32(2.0)
p = 0
swap = True
reduction = "none"
eps = 1e-5
anchor = Tensor(anchor_array)
positive = Tensor(positive_array)
negative = Tensor(negative_array)
mind_margin = Tensor(margin)
triplet_margin_loss = NetTripletMarginLoss(p=p, swap=swap, reduction=reduction, eps=eps)
output_ms = triplet_margin_loss(anchor, positive, negative, mind_margin)
print(output_ms)
torch_anchor = torch.tensor(anchor_array)
torch_positive = torch.tensor(positive_array)
torch_negative = torch.tensor(negative_array)
torch_loss = torch.nn.TripletMarginLoss(margin=margin, p=p, swap=swap, reduction=reduction, eps=eps)
expect = torch_loss(torch_anchor, torch_positive, torch_negative)
assert np.allclose(output_ms.asnumpy(),
expect.numpy(),
rtol=1e-5,
atol=1e-5,
equal_nan=False)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_triplet_margin_loss_float32():
"""
Feature: Input type of float32
Description: Input type of [float32, float32, float32, float32].
Expectation: success.
"""
for mode in [context.GRAPH_MODE, context.PYNATIVE_MODE]:
context.set_context(mode=mode, device_target="GPU")
data_type = np.float32
anchor_array = np.array([[1.3, 20.5, 5.6],
[3.5, 4.8, 7.2],
[0.2, 0.01, 1],
[4, 4.1, 20]]).astype(data_type)
positive_array = np.array([[2., 10., 1.],
[6., 7., 10.],
[13., 4., 1.],
[0.33, -4, -1.5]]).astype(data_type)
negative_array = np.array([[2., 21., 6.],
[68., 9., 10.],
[131., 25., 16.],
[0.31, -0.14, -16.]]).astype(data_type)
margin = 2.0
p = 1
swap = False
reduction = "none"
eps = 1e-6
anchor = Tensor(anchor_array)
positive = Tensor(positive_array)
negative = Tensor(negative_array)
mind_margin = Tensor(margin)
triplet_margin_loss = NetTripletMarginLoss(p=p, swap=swap, reduction=reduction, eps=eps)
output_ms = triplet_margin_loss(anchor, positive, negative, mind_margin)
torch_anchor = torch.tensor(anchor_array)
torch_positive = torch.tensor(positive_array)
torch_negative = torch.tensor(negative_array)
torch_loss = torch.nn.TripletMarginLoss(margin=margin, p=p, swap=swap, reduction=reduction, eps=eps)
expect = torch_loss(torch_anchor, torch_positive, torch_negative)
assert np.allclose(output_ms.asnumpy(),
expect.numpy(),
rtol=1e-4,
atol=1e-4,
equal_nan=False)