!43571 [assistant][ops][I4ZZUQ] [I56J61] New GPU operator implementations, including HSVToRGB, RGBToHSV, and SparseApplyCenteredRMSProp
Merge pull request !43571 from 杨鹏康/RGBToHSV
Commit: 4f1843f2d0
@@ -34,6 +34,7 @@
namespace mindspore {
// Op names. Ops that do not exist in operator/ops.h have their names defined here.
constexpr auto kSparseApplyCenteredRMSPropOpName = "SparseApplyCenteredRMSProp";
constexpr auto kAbsOpName = "Abs";
constexpr auto kAdamApplyOneAssignOpName = "AdamApplyOneAssign";
constexpr auto kAdamApplyOneOpName = "AdamApplyOne";
@@ -0,0 +1,95 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_HSVTORGB_HELPER_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_HSVTORGB_HELPER_H_
#include <memory>
#include <string>
#include <vector>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/hsvtorgb_impl.cuh"

namespace mindspore {
namespace cukernel {
template <typename T, typename S>
class HsvToRgbHelperGpuKernel : public GpuKernelHelperBase {
 public:
  explicit HsvToRgbHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id)
      : GpuKernelHelperBase(kernel_name, device_id) {
    is_null_input_ = false;
  }

  virtual ~HsvToRgbHelperGpuKernel() = default;
  int CalMemSize(const std::vector<std::vector<int64_t>> &input_shapes,
                 const std::vector<std::vector<int64_t>> &output_shapes) override {
    constexpr size_t INPUT_NUM = 1;
    constexpr size_t OUTPUT_NUM = 1;
    ResetResource();
    int inp_flag = CalShapesSizeInBytes<T>(input_shapes, INPUT_NUM, kernel_name_, "input_shapes", &input_size_list_);
    if (inp_flag == -1) {
      return inp_flag;
    }
    input_shape_ = input_shapes[0];
    int out_flag =
      CalShapesSizeInBytes<S>(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_);
    if (out_flag == -1) {
      return out_flag;
    }
    is_null_input_ = (inp_flag == 1 || out_flag == 1);
    return CheckKernelParam();
  }

  int Process(const std::vector<void *> &input_ptrs, const std::vector<void *> &output_ptrs,
              const std::vector<void *> &work_ptrs, void *cuda_stream) override {
    if (is_null_input_) {
      return 0;
    }

    size_t in = input_shape_[0];
    size_t ic = input_shape_[1];
    size_t ih = input_shape_[2];
    size_t iw = input_shape_[3];
    constexpr int shape_n = 3;

    if (iw != shape_n) {
      MS_LOG(ERROR) << "For '" << kernel_name_ << "', the last dimension must be 3, but got " << iw << ".";
      return -1;
    }

    T *input_ptr = nullptr;
    S *output_ptr = nullptr;
    int flag = GetDeviceAddress<T>(input_ptrs, 0, kernel_name_, &input_ptr);
    if (flag != 0) {
      return flag;
    }

    flag = GetDeviceAddress<S>(output_ptrs, 0, kernel_name_, &output_ptr);
    if (flag != 0) {
      return flag;
    }

    CalHsvtorgb(in * ic * ih * iw, input_ptr, output_ptr, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream));
    return 0;
  }

 private:
  std::vector<int64_t> input_shape_;
  bool is_null_input_;
};
}  // namespace cukernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_HSVTORGB_HELPER_H_
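Reviewer note on the helper contract: the framework first calls CalMemSize() with the resolved shapes so the helper can register input/output buffer sizes, then calls Process() with the allocated device pointers and a CUDA stream. A minimal host-side sketch of that lifecycle follows; the wrapper function, the shapes, and the device pointers here are hypothetical, for illustration only, assuming the framework headers are available.

// Hypothetical driver sketch for HsvToRgbHelperGpuKernel<float, float>.
// d_in / d_out are assumed to be valid device allocations of matching size.
#include <cstdint>
#include <vector>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/hsvtorgb_helper.h"

void RunHsvToRgbExample(void *d_in, void *d_out, cudaStream_t stream) {
  mindspore::cukernel::HsvToRgbHelperGpuKernel<float, float> helper("HSVToRGB", /*device_id=*/0);
  // One 4-D input whose last dimension is the 3 HSV channels.
  std::vector<std::vector<int64_t>> in_shapes = {{1, 4, 4, 3}};
  std::vector<std::vector<int64_t>> out_shapes = {{1, 4, 4, 3}};
  if (helper.CalMemSize(in_shapes, out_shapes) != 0) {
    return;  // shape or parameter check failed
  }
  (void)helper.Process({d_in}, {d_out}, /*work_ptrs=*/{}, stream);
}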
@@ -0,0 +1,100 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_RGBTOHSV_HELPER_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_RGBTOHSV_HELPER_H_
#include <memory>
#include <string>
#include <vector>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/rgbtohsv_impl.cuh"

namespace mindspore {
namespace cukernel {
template <typename T, typename S>
class RgbToHsvHelperGpuKernel : public GpuKernelHelperBase {
 public:
  explicit RgbToHsvHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id)
      : GpuKernelHelperBase(kernel_name, device_id) {
    is_null_input_ = false;
  }

  virtual ~RgbToHsvHelperGpuKernel() = default;
  int CalMemSize(const std::vector<std::vector<int64_t>> &input_shapes,
                 const std::vector<std::vector<int64_t>> &output_shapes) override {
    constexpr size_t INPUT_NUM = 1;
    constexpr size_t OUTPUT_NUM = 1;
    ResetResource();
    int inp_flag = CalShapesSizeInBytes<T>(input_shapes, INPUT_NUM, kernel_name_, "input_shapes", &input_size_list_);
    if (inp_flag == -1) {
      return inp_flag;
    }
    input_shape_ = input_shapes[0];
    int out_flag =
      CalShapesSizeInBytes<S>(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_);
    if (out_flag == -1) {
      return out_flag;
    }
    is_null_input_ = (inp_flag == 1 || out_flag == 1);
    return CheckKernelParam();
  }

  int Process(const std::vector<void *> &input_ptrs, const std::vector<void *> &output_ptrs,
              const std::vector<void *> &work_ptrs, void *cuda_stream) override {
    if (is_null_input_) {
      return 0;
    }

    input0_elements_nums_ = 1;
    size_t n = 0;
    for (size_t i = 0; i < input_shape_.size(); i++) {
      input0_elements_nums_ *= input_shape_[i];
      n++;
    }

    size_t N = input_shape_[n - 1];
    constexpr int shape_n = 3;

    if (N != shape_n) {
      MS_LOG(ERROR) << "For '" << kernel_name_ << "', the last dimension must be 3, but got " << N << ".";
      return -1;
    }

    T *input_ptr = nullptr;
    S *output_ptr = nullptr;
    int flag = GetDeviceAddress<T>(input_ptrs, 0, kernel_name_, &input_ptr);
    if (flag != 0) {
      return flag;
    }

    flag = GetDeviceAddress<S>(output_ptrs, 0, kernel_name_, &output_ptr);
    if (flag != 0) {
      return flag;
    }

    CalRgbtohsv(input0_elements_nums_, input_ptr, output_ptr, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream));
    return 0;
  }

 private:
  std::vector<int64_t> input_shape_;
  bool is_null_input_;
  size_t input0_elements_nums_;
};
}  // namespace cukernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_RGBTOHSV_HELPER_H_
@@ -0,0 +1,96 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <iostream>
#include "hsvtorgb_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
#include "utils/ms_utils.h"
#include "include/cuda_fp16.h"

// Note: the output pointers are unconditionally dereferenced, so they must never be null;
// the former "= 0" default arguments were removed for that reason.
template <typename T>
__device__ __forceinline__ void hsv2rgb(const T h, const T s, const T v, T *r, T *g, T *b) {
  const T h60 = h * T(6.0);
  const T h60f = T(floor(static_cast<float>(h60)));
  const int hi = static_cast<int>(h60f) % 6;
  const T f = h60 - h60f;
  const T p = v * (T(1) - s);
  const T q = v * (T(1) - f * s);
  const T t = v * (T(1) - (T(1) - f) * s);
  switch (hi) {
    case 0:
      *r = v; *g = t; *b = p;
      break;
    case 1:
      *r = q; *g = v; *b = p;
      break;
    case 2:
      *r = p; *g = v; *b = t;
      break;
    case 3:
      *r = p; *g = q; *b = v;
      break;
    case 4:
      *r = t; *g = p; *b = v;
      break;
    case 5:
      *r = v; *g = p; *b = q;
      break;
    default:
      break;
  }
}

template <typename T>
__global__ void Hsvtorgb(const size_t input_size, const T *input, T *output) {
  for (size_t idx = blockDim.x * blockIdx.x + threadIdx.x;
       idx < input_size / 3; idx += blockDim.x * gridDim.x) {
    T r, g, b;
    hsv2rgb(input[idx * 3], input[idx * 3 + 1], input[idx * 3 + 2], &r, &g, &b);
    output[idx * 3] = r;
    output[idx * 3 + 1] = g;
    output[idx * 3 + 2] = b;
  }
  return;
}

template <>
__global__ void Hsvtorgb(const size_t input_size, const half *input, half *output) {
  for (size_t idx = blockDim.x * blockIdx.x + threadIdx.x;
       idx < input_size / 3; idx += blockDim.x * gridDim.x) {
    float r, g, b;
    hsv2rgb(static_cast<float>(input[idx * 3]),
            static_cast<float>(input[idx * 3 + 1]),
            static_cast<float>(input[idx * 3 + 2]), &r, &g, &b);
    output[idx * 3] = static_cast<half>(r);
    output[idx * 3 + 1] = static_cast<half>(g);
    output[idx * 3 + 2] = static_cast<half>(b);
  }
  return;
}

template <typename T>
void CalHsvtorgb(const size_t input_size, const T *input, T *output,
                 const uint32_t &device_id, cudaStream_t cuda_stream) {
  Hsvtorgb<<<CUDA_BLOCKS(device_id, input_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(input_size, input, output);
  return;
}

template CUDA_LIB_EXPORT void CalHsvtorgb<half>(const size_t input_size, const half *input, half *output,
                                                const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalHsvtorgb<float>(const size_t input_size, const float *input, float *output,
                                                 const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalHsvtorgb<double>(const size_t input_size, const double *input, double *output,
                                                  const uint32_t &device_id, cudaStream_t cuda_stream);
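For reference, the per-pixel conversion implemented by hsv2rgb above is the standard sector formula; with h, s, v in [0, 1]:

k = \lfloor 6h \rfloor \bmod 6, \qquad f = 6h - \lfloor 6h \rfloor,
p = v(1 - s), \qquad q = v(1 - fs), \qquad t = v(1 - (1 - f)s),
(r, g, b) =
  \begin{cases}
    (v, t, p) & k = 0 \\
    (q, v, p) & k = 1 \\
    (p, v, t) & k = 2 \\
    (p, q, v) & k = 3 \\
    (t, p, v) & k = 4 \\
    (v, p, q) & k = 5
  \end{cases}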
@@ -0,0 +1,41 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HSVTORGB_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HSVTORGB_IMPL_CUH_
#include "include/cuda_fp16.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"

#ifdef __cplusplus
extern "C" {
#endif
CUDA_LIB_EXPORT void CalHsvtorgbFp16(const size_t input_size, const half *input, half *output,
                                     const uint32_t &device_id, cudaStream_t cuda_stream);

CUDA_LIB_EXPORT void CalHsvtorgbFp32(const size_t input_size, const float *input, float *output,
                                     const uint32_t &device_id, cudaStream_t cuda_stream);

CUDA_LIB_EXPORT void CalHsvtorgbFp64(const size_t input_size, const double *input, double *output,
                                     const uint32_t &device_id, cudaStream_t cuda_stream);
#ifdef __cplusplus
}
#endif

template <typename T>
CUDA_LIB_EXPORT void CalHsvtorgb(const size_t input_size, const T *input,
                                 T *output, const uint32_t &device_id, cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HSVTORGB_IMPL_CUH_
@@ -0,0 +1,106 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <iostream>
#include "rgbtohsv_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
#include "utils/ms_utils.h"
#include "include/cuda_fp16.h"

template <typename T>
__device__ __forceinline__ T Max(T a, T b) {
  return a > b ? a : b;
}

template <typename T>
__device__ __forceinline__ T Min(T a, T b) {
  return a < b ? a : b;
}

template <typename T>
__device__ __forceinline__ T Mod(T a, T b) {
  return a - b * floor(a / b);
}

template <typename T>
__device__ __forceinline__ void rgb2hsv(const T r, const T g, const T b, T *h, T *s, T *v) {
  const T M = Max(r, Max(g, b));
  const T m = Min(r, Min(g, b));
  const T chroma = M - m;
  *h = 0.0f;
  *s = 0.0f;
  if (chroma > T(0.0f)) {
    if (M == r) {
      const T num = (g - b) / chroma;
      const T sign = copysignf(1.0f, num);
      *h = ((sign < 0.0f) * 6.0f + sign * Mod(sign * num, T(6.0f))) / 6.0f;
    } else if (M == g) {
      *h = ((b - r) / chroma + 2.0f) / 6.0f;
    } else {
      *h = ((r - g) / chroma + 4.0f) / 6.0f;
    }
  } else {
    *h = 0.0f;
  }
  if (M > 0.0) {
    *s = chroma / M;
  } else {
    *s = 0.0f;
  }
  *v = M;
  return;
}

template <typename T>
__global__ void Rgbtohsv(const size_t input_size, const T *input, T *output) {
  for (size_t idx = blockDim.x * blockIdx.x + threadIdx.x;
       idx < input_size / 3; idx += blockDim.x * gridDim.x) {
    T h, s, v;
    rgb2hsv(input[idx * 3], input[idx * 3 + 1], input[idx * 3 + 2], &h, &s, &v);
    output[idx * 3] = h;
    output[idx * 3 + 1] = s;
    output[idx * 3 + 2] = v;
  }
  return;
}

template <>
__global__ void Rgbtohsv(const size_t input_size, const half *input, half *output) {
  for (size_t idx = blockDim.x * blockIdx.x + threadIdx.x;
       idx < input_size / 3; idx += blockDim.x * gridDim.x) {
    float h, s, v;
    rgb2hsv(static_cast<float>(input[idx * 3]),
            static_cast<float>(input[idx * 3 + 1]),
            static_cast<float>(input[idx * 3 + 2]), &h, &s, &v);
    output[idx * 3] = static_cast<half>(h);
    output[idx * 3 + 1] = static_cast<half>(s);
    output[idx * 3 + 2] = static_cast<half>(v);
  }
  return;
}

template <typename T>
void CalRgbtohsv(const size_t input_size, const T *input,
                 T *output, const uint32_t &device_id, cudaStream_t cuda_stream) {
  Rgbtohsv<<<CUDA_BLOCKS(device_id, input_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(input_size, input, output);
  return;
}

template CUDA_LIB_EXPORT void CalRgbtohsv<half>(const size_t input_size, const half *input, half *output,
                                                const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalRgbtohsv<float>(const size_t input_size, const float *input, float *output,
                                                 const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalRgbtohsv<double>(const size_t input_size, const double *input, double *output,
                                                  const uint32_t &device_id, cudaStream_t cuda_stream);
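Conversely, rgb2hsv above computes the standard inverse mapping; with M = max(r, g, b), m = min(r, g, b), and chroma C = M - m:

v = M, \qquad
s = \begin{cases} C / M & M > 0 \\ 0 & \text{otherwise} \end{cases}, \qquad
h = \frac{1}{6}
  \begin{cases}
    \left(\frac{g - b}{C}\right) \bmod 6 & M = r \\
    \frac{b - r}{C} + 2 & M = g \\
    \frac{r - g}{C} + 4 & M = b
  \end{cases}

with h = 0 when C = 0 (achromatic input).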
@@ -0,0 +1,41 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_RGBTOHSV_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_RGBTOHSV_IMPL_CUH_
#include "include/cuda_fp16.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"

#ifdef __cplusplus
extern "C" {
#endif
CUDA_LIB_EXPORT void CalRgbtohsvFp16(const size_t input_size, const half *input, half *output,
                                     const uint32_t &device_id, cudaStream_t cuda_stream);

CUDA_LIB_EXPORT void CalRgbtohsvFp32(const size_t input_size, const float *input, float *output,
                                     const uint32_t &device_id, cudaStream_t cuda_stream);

CUDA_LIB_EXPORT void CalRgbtohsvFp64(const size_t input_size, const double *input, double *output,
                                     const uint32_t &device_id, cudaStream_t cuda_stream);
#ifdef __cplusplus
}
#endif

template <typename T>
CUDA_LIB_EXPORT void CalRgbtohsv(const size_t input_size, const T *input,
                                 T *output, const uint32_t &device_id, cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_RGBTOHSV_IMPL_CUH_
@@ -0,0 +1,298 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_apply_centered_rms_prop_impl.cuh"
#include "include/cuda_fp16.h"

template <typename T>
__device__ __forceinline__ T RsqrtFunc(T x) {
  return __frsqrt_rn(x);
}

template <>
__device__ __forceinline__ half RsqrtFunc(half x) {
  return hrsqrt(x);
}

template <>
__device__ __forceinline__ double RsqrtFunc(double x) {
  return rsqrt(x);
}

template <typename T, typename S>
__global__ void SparseApplyCenteredRMSPropUpdate(const size_t size, const size_t indices_size, const bool use_locking,
                                                 T *learning_rate, T *decay_rate,
                                                 T *epsilon, T *momentum, const T *gradient, const S *indices,
                                                 T *variable, T *mean_grad, T *mean_square, T *mom, T *variable_out) {
  const int64_t inner_size = static_cast<int64_t>(size * sizeof(int64_t) / sizeof(S));
  const T con1 = static_cast<T>(1);
  for (int64_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < static_cast<int64_t>(size);
       pos += gridDim.x * blockDim.x) {
    const int64_t index = pos / inner_size;
    const int64_t inner_pos = pos % inner_size;
    const int64_t grad_pos = pos;
    const int64_t cur_pos = indices[index] * inner_size + inner_pos;

    mean_square[cur_pos] = (*decay_rate) * mean_square[cur_pos] + (con1 - (*decay_rate)) *
                           gradient[grad_pos] * gradient[grad_pos];
    mean_grad[cur_pos] = mean_grad[cur_pos] * (*decay_rate) + gradient[grad_pos] * (con1 - (*decay_rate));
    const T denom = mean_square[cur_pos] + (*epsilon) - mean_grad[cur_pos] * mean_grad[cur_pos];
    mom[cur_pos] = (*learning_rate) * gradient[grad_pos] * RsqrtFunc(denom) + mom[cur_pos] * (*momentum);
    variable_out[cur_pos] = variable[cur_pos] - mom[cur_pos];
  }
}

template <typename S>
__global__ void SparseApplyCenteredRMSPropUpdate(const size_t size, const size_t indices_size, const bool use_locking,
                                                 double *learning_rate, double *decay_rate, double *epsilon,
                                                 double *momentum, const double *gradient, const S *indices,
                                                 double *variable, double *mean_grad, double *mean_square,
                                                 double *mom, double *variable_out) {
  const int64_t inner_size = static_cast<int64_t>(size * sizeof(int64_t) / sizeof(S));
  const double con1 = static_cast<double>(1);
  for (int64_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < static_cast<int64_t>(size);
       pos += gridDim.x * blockDim.x) {
    const int64_t index = pos / inner_size;
    const int64_t inner_pos = pos % inner_size;
    const int64_t grad_pos = pos;
    const int64_t cur_pos = indices[index] * inner_size + inner_pos;

    mean_square[cur_pos] = (*decay_rate) * mean_square[cur_pos] + (con1 - (*decay_rate)) *
                           gradient[grad_pos] * gradient[grad_pos];
    mean_grad[cur_pos] = mean_grad[cur_pos] * (*decay_rate) + gradient[grad_pos] * (con1 - (*decay_rate));
    const double denom = mean_square[cur_pos] + (*epsilon) - mean_grad[cur_pos] * mean_grad[cur_pos];
    mom[cur_pos] = (*learning_rate) * gradient[grad_pos] * RsqrtFunc(denom) + mom[cur_pos] * (*momentum);
    variable_out[cur_pos] = variable[cur_pos] - mom[cur_pos];
  }
}

// The half specialization accumulates in float to limit rounding error.
template <typename S>
__global__ void SparseApplyCenteredRMSPropUpdate(const size_t size, const size_t indices_size, const bool use_locking,
                                                 half *learning_rate, half *decay_rate, half *epsilon,
                                                 half *momentum, const half *gradient, const S *indices, half *variable,
                                                 half *mean_grad, half *mean_square, half *mom, half *variable_out) {
  const int64_t inner_size = static_cast<int64_t>(size * sizeof(int64_t) / sizeof(S));
  const float con1 = static_cast<float>(1);
  for (int64_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < static_cast<int64_t>(size);
       pos += gridDim.x * blockDim.x) {
    const int64_t index = pos / inner_size;
    const int64_t inner_pos = pos % inner_size;
    const int64_t grad_pos = pos;
    const int64_t cur_pos = indices[index] * inner_size + inner_pos;

    mean_square[cur_pos] = static_cast<float>(*decay_rate) * static_cast<float>(mean_square[cur_pos]) +
                           (con1 - static_cast<float>(*decay_rate)) *
                           static_cast<float>(gradient[grad_pos]) * static_cast<float>(gradient[grad_pos]);
    mean_grad[cur_pos] = static_cast<float>(mean_grad[cur_pos]) * static_cast<float>(*decay_rate) +
                         static_cast<float>(gradient[grad_pos]) * (con1 - static_cast<float>(*decay_rate));
    const float denom = static_cast<float>(mean_square[cur_pos]) + static_cast<float>(*epsilon) -
                        static_cast<float>(mean_grad[cur_pos]) * static_cast<float>(mean_grad[cur_pos]);
    mom[cur_pos] = static_cast<float>(*learning_rate) * static_cast<float>(gradient[grad_pos]) *
                   static_cast<float>(RsqrtFunc(denom)) + static_cast<float>(mom[cur_pos]) *
                   static_cast<float>(*momentum);
    variable_out[cur_pos] = static_cast<half>(static_cast<float>(variable[cur_pos]) -
                                              static_cast<float>(mom[cur_pos]));
  }
}

template <typename T, typename S>
void CalSparseApplyCenteredRMSProp(const size_t size, const size_t indices_size, const bool use_locking,
                                   T *learning_rate, T *decay_rate,
                                   T *epsilon, T *momentum, const T *gradient, const S *indices,
                                   T *variable, T *mean_grad, T *mean_square, T *mom, T *variable_out,
                                   cudaStream_t cuda_stream) {
  SparseApplyCenteredRMSPropUpdate<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
    size, indices_size, use_locking, learning_rate, decay_rate, epsilon, momentum, gradient, indices, variable,
    mean_grad, mean_square, mom, variable_out);
}

template CUDA_LIB_EXPORT void CalSparseApplyCenteredRMSProp<half, int32_t>(
  const size_t size, const size_t indices_size, const bool use_locking, half *learning_rate, half *decay_rate,
  half *epsilon, half *momentum, const half *gradient, const int32_t *indices, half *variable, half *mean_grad,
  half *mean_square, half *mom, half *variable_out, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSparseApplyCenteredRMSProp<float, int32_t>(
  const size_t size, const size_t indices_size, const bool use_locking, float *learning_rate, float *decay_rate,
  float *epsilon, float *momentum, const float *gradient, const int32_t *indices, float *variable, float *mean_grad,
  float *mean_square, float *mom, float *variable_out, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSparseApplyCenteredRMSProp<double, int32_t>(
  const size_t size, const size_t indices_size, const bool use_locking, double *learning_rate, double *decay_rate,
  double *epsilon, double *momentum, const double *gradient, const int32_t *indices, double *variable,
  double *mean_grad, double *mean_square, double *mom, double *variable_out, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSparseApplyCenteredRMSProp<int8_t, int32_t>(
  const size_t size, const size_t indices_size, const bool use_locking, int8_t *learning_rate, int8_t *decay_rate,
  int8_t *epsilon, int8_t *momentum, const int8_t *gradient, const int32_t *indices, int8_t *variable,
  int8_t *mean_grad, int8_t *mean_square, int8_t *mom, int8_t *variable_out, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSparseApplyCenteredRMSProp<int16_t, int32_t>(
  const size_t size, const size_t indices_size, const bool use_locking, int16_t *learning_rate, int16_t *decay_rate,
  int16_t *epsilon, int16_t *momentum, const int16_t *gradient, const int32_t *indices, int16_t *variable,
  int16_t *mean_grad, int16_t *mean_square, int16_t *mom, int16_t *variable_out, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSparseApplyCenteredRMSProp<int32_t, int32_t>(
  const size_t size, const size_t indices_size, const bool use_locking, int32_t *learning_rate, int32_t *decay_rate,
  int32_t *epsilon, int32_t *momentum, const int32_t *gradient, const int32_t *indices, int32_t *variable,
  int32_t *mean_grad, int32_t *mean_square, int32_t *mom, int32_t *variable_out, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSparseApplyCenteredRMSProp<int64_t, int32_t>(
  const size_t size, const size_t indices_size, const bool use_locking, int64_t *learning_rate, int64_t *decay_rate,
  int64_t *epsilon, int64_t *momentum, const int64_t *gradient, const int32_t *indices, int64_t *variable,
  int64_t *mean_grad, int64_t *mean_square, int64_t *mom, int64_t *variable_out, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSparseApplyCenteredRMSProp<uint8_t, int32_t>(
  const size_t size, const size_t indices_size, const bool use_locking, uint8_t *learning_rate, uint8_t *decay_rate,
  uint8_t *epsilon, uint8_t *momentum, const uint8_t *gradient, const int32_t *indices, uint8_t *variable,
  uint8_t *mean_grad, uint8_t *mean_square, uint8_t *mom, uint8_t *variable_out, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSparseApplyCenteredRMSProp<uint16_t, int32_t>(
  const size_t size, const size_t indices_size, const bool use_locking, uint16_t *learning_rate, uint16_t *decay_rate,
  uint16_t *epsilon, uint16_t *momentum, const uint16_t *gradient, const int32_t *indices, uint16_t *variable,
  uint16_t *mean_grad, uint16_t *mean_square, uint16_t *mom, uint16_t *variable_out, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSparseApplyCenteredRMSProp<uint32_t, int32_t>(
  const size_t size, const size_t indices_size, const bool use_locking, uint32_t *learning_rate, uint32_t *decay_rate,
  uint32_t *epsilon, uint32_t *momentum, const uint32_t *gradient, const int32_t *indices, uint32_t *variable,
  uint32_t *mean_grad, uint32_t *mean_square, uint32_t *mom, uint32_t *variable_out, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSparseApplyCenteredRMSProp<uint64_t, int32_t>(
  const size_t size, const size_t indices_size, const bool use_locking, uint64_t *learning_rate, uint64_t *decay_rate,
  uint64_t *epsilon, uint64_t *momentum, const uint64_t *gradient, const int32_t *indices, uint64_t *variable,
  uint64_t *mean_grad, uint64_t *mean_square, uint64_t *mom, uint64_t *variable_out, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSparseApplyCenteredRMSProp<half, int64_t>(
  const size_t size, const size_t indices_size, const bool use_locking, half *learning_rate, half *decay_rate,
  half *epsilon, half *momentum, const half *gradient, const int64_t *indices, half *variable, half *mean_grad,
  half *mean_square, half *mom, half *variable_out, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSparseApplyCenteredRMSProp<float, int64_t>(
  const size_t size, const size_t indices_size, const bool use_locking, float *learning_rate, float *decay_rate,
  float *epsilon, float *momentum, const float *gradient, const int64_t *indices, float *variable, float *mean_grad,
  float *mean_square, float *mom, float *variable_out, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSparseApplyCenteredRMSProp<double, int64_t>(
  const size_t size, const size_t indices_size, const bool use_locking, double *learning_rate, double *decay_rate,
  double *epsilon, double *momentum, const double *gradient, const int64_t *indices, double *variable,
  double *mean_grad, double *mean_square, double *mom, double *variable_out, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSparseApplyCenteredRMSProp<int8_t, int64_t>(
  const size_t size, const size_t indices_size, const bool use_locking, int8_t *learning_rate, int8_t *decay_rate,
  int8_t *epsilon, int8_t *momentum, const int8_t *gradient, const int64_t *indices, int8_t *variable,
  int8_t *mean_grad, int8_t *mean_square, int8_t *mom, int8_t *variable_out, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSparseApplyCenteredRMSProp<int16_t, int64_t>(
  const size_t size, const size_t indices_size, const bool use_locking, int16_t *learning_rate, int16_t *decay_rate,
  int16_t *epsilon, int16_t *momentum, const int16_t *gradient, const int64_t *indices, int16_t *variable,
  int16_t *mean_grad, int16_t *mean_square, int16_t *mom, int16_t *variable_out, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSparseApplyCenteredRMSProp<int32_t, int64_t>(
  const size_t size, const size_t indices_size, const bool use_locking, int32_t *learning_rate, int32_t *decay_rate,
  int32_t *epsilon, int32_t *momentum, const int32_t *gradient, const int64_t *indices, int32_t *variable,
  int32_t *mean_grad, int32_t *mean_square, int32_t *mom, int32_t *variable_out, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSparseApplyCenteredRMSProp<int64_t, int64_t>(
  const size_t size, const size_t indices_size, const bool use_locking, int64_t *learning_rate, int64_t *decay_rate,
  int64_t *epsilon, int64_t *momentum, const int64_t *gradient, const int64_t *indices, int64_t *variable,
  int64_t *mean_grad, int64_t *mean_square, int64_t *mom, int64_t *variable_out, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSparseApplyCenteredRMSProp<uint8_t, int64_t>(
  const size_t size, const size_t indices_size, const bool use_locking, uint8_t *learning_rate, uint8_t *decay_rate,
  uint8_t *epsilon, uint8_t *momentum, const uint8_t *gradient, const int64_t *indices, uint8_t *variable,
  uint8_t *mean_grad, uint8_t *mean_square, uint8_t *mom, uint8_t *variable_out, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSparseApplyCenteredRMSProp<uint16_t, int64_t>(
  const size_t size, const size_t indices_size, const bool use_locking, uint16_t *learning_rate, uint16_t *decay_rate,
  uint16_t *epsilon, uint16_t *momentum, const uint16_t *gradient, const int64_t *indices, uint16_t *variable,
  uint16_t *mean_grad, uint16_t *mean_square, uint16_t *mom, uint16_t *variable_out, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSparseApplyCenteredRMSProp<uint32_t, int64_t>(
  const size_t size, const size_t indices_size, const bool use_locking, uint32_t *learning_rate, uint32_t *decay_rate,
  uint32_t *epsilon, uint32_t *momentum, const uint32_t *gradient, const int64_t *indices, uint32_t *variable,
  uint32_t *mean_grad, uint32_t *mean_square, uint32_t *mom, uint32_t *variable_out, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSparseApplyCenteredRMSProp<uint64_t, int64_t>(
  const size_t size, const size_t indices_size, const bool use_locking, uint64_t *learning_rate, uint64_t *decay_rate,
  uint64_t *epsilon, uint64_t *momentum, const uint64_t *gradient, const int64_t *indices, uint64_t *variable,
  uint64_t *mean_grad, uint64_t *mean_square, uint64_t *mom, uint64_t *variable_out, cudaStream_t cuda_stream);
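The per-row update applied by these kernels is the centered RMSProp rule, restricted to the rows named in indices; with decay rate rho, momentum mu, learning rate eta, and gradient g:

ms \leftarrow \rho \, ms + (1 - \rho) \, g^2
mg \leftarrow \rho \, mg + (1 - \rho) \, g
mom \leftarrow \mu \, mom + \eta \, \frac{g}{\sqrt{ms + \epsilon - mg^2}}
var \leftarrow var - mom

Centering subtracts mg^2, an estimate of the squared mean gradient, so the denominator approximates the gradient's variance rather than its raw second moment.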
@@ -0,0 +1,28 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_APPLY_CENTERED_RMS_PROP_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_APPLY_CENTERED_RMS_PROP_IMPL_CUH_

#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
template <typename T, typename S>
CUDA_LIB_EXPORT void CalSparseApplyCenteredRMSProp(const size_t size, const size_t indices_size,
                                                   const bool use_locking, T *learning_rate, T *decay_rate, T *epsilon,
                                                   T *momentum, const T *gradient, const S *indices, T *variable,
                                                   T *mean_grad, T *mean_square, T *mom, T *variable_out,
                                                   cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_APPLY_CENTERED_RMS_PROP_IMPL_CUH_
@ -0,0 +1,489 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "mindspore/core/ops/sparse_apply_centered_rms_prop.h"
|
||||
#include "plugin/device/gpu/kernel/nn/sparse_apply_centered_rms_prop_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
constexpr size_t kSparseApplyCenteredRMSPropInputsNum = 10;
|
||||
constexpr size_t kVarIndex = 0;
|
||||
constexpr size_t kMgIndex = 1;
|
||||
constexpr size_t kMsIndex = 2;
|
||||
constexpr size_t kMomIndex = 3;
|
||||
constexpr size_t kLrIndex = 4;
|
||||
constexpr size_t kRhoIndex = 5;
|
||||
constexpr size_t kMomentumIndex = 6;
|
||||
constexpr size_t kEpsilonIndex = 7;
|
||||
constexpr size_t kGradIndex = 8;
|
||||
constexpr size_t kIndicesIndex = 9;
|
||||
} // namespace
|
||||
|
||||
bool SparseApplyCenteredRMSPropGpuKernelMod::Init(const BaseOperatorPtr &base_operator,
|
||||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
MS_EXCEPTION_IF_NULL(base_operator);
|
||||
kernel_name_ = base_operator->name();
|
||||
if (kernel_name_ != prim::kPrimSparseApplyCenteredRMSProp->name()) {
|
||||
MS_LOG(ERROR) << "For 'SparseApplyCenteredRMSProp', the kernel name must be 'SparseApplyCenteredRMSProp', but got "
|
||||
<< kernel_name_;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto kernel_ptr = std::dynamic_pointer_cast<ops::SparseApplyCenteredRMSProp>(base_operator);
|
||||
MS_EXCEPTION_IF_NULL(kernel_ptr);
|
||||
if (!kernel_ptr) {
|
||||
MS_LOG(ERROR) << "SparseApplyCenteredRMSProp ops failed!";
|
||||
return false;
|
||||
}
|
||||
use_locking_ = kernel_ptr->get_use_locking();
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
|
||||
if (inputs.empty() || outputs.empty()) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
int SparseApplyCenteredRMSPropGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
|
||||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
int ret = KernelMod::Resize(base_operator, inputs, outputs);
|
||||
if (ret != 0) {
|
||||
return ret;
|
||||
}
|
||||
if (input_size_list_.size() != kSparseApplyCenteredRMSPropInputsNum) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "' input size must be equal 10 but got " << input_size_list_.size();
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
std::vector<int64_t> var_shape = inputs[kVarIndex]->GetShapeVector();
|
||||
std::vector<int64_t> mg_shape = inputs[kMgIndex]->GetShapeVector();
|
||||
std::vector<int64_t> ms_shape = inputs[kMsIndex]->GetShapeVector();
|
||||
std::vector<int64_t> mom_shape = inputs[kMomIndex]->GetShapeVector();
|
||||
std::vector<int64_t> lr_shape = inputs[kLrIndex]->GetShapeVector();
|
||||
std::vector<int64_t> rho_shape = inputs[kRhoIndex]->GetShapeVector();
|
||||
std::vector<int64_t> momentum_shape = inputs[kMomentumIndex]->GetShapeVector();
|
||||
std::vector<int64_t> epsilon_shape = inputs[kEpsilonIndex]->GetShapeVector();
|
||||
std::vector<int64_t> grad_shape = inputs[kGradIndex]->GetShapeVector();
|
||||
std::vector<int64_t> indices_shape = inputs[kIndicesIndex]->GetShapeVector();
|
||||
if (!lr_shape.empty()) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', lr is not a scalar.";
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
if (!rho_shape.empty()) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', rho is not a scalar.";
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
if (!momentum_shape.empty()) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', momentum is not a scalar.";
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
if (!epsilon_shape.empty()) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', epsilon is not a scalar.";
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
if (var_shape.empty()) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_
|
||||
<< "', the dimension of 'var' must be at least 1-D, but got scalar or None.";
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
if (!IsSameShape(var_shape, mg_shape)) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_
|
||||
<< "', the shape of 'mg' must be the same as the shape of 'var', "
|
||||
"but got the shape of 'mg': "
|
||||
<< Vector2Str(mg_shape) << " and the shape of 'var': " << Vector2Str(var_shape);
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
if (var_shape.size() != ms_shape.size()) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_
|
||||
<< "', the dimension of 'ms' must be the same as the dimension of "
|
||||
"'var', but got the dimension of 'ms': "
|
||||
<< ms_shape.size() << " and the dimension of 'var': " << var_shape.size() << ".";
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
if (var_shape.size() != mom_shape.size()) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_
|
||||
<< "', the dimension of 'mom' must be the same as the dimension of "
|
||||
"'var', but got the dimension of 'mom': "
|
||||
<< mom_shape.size() << " and the dimension of 'var': " << var_shape.size() << ".";
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
for (size_t i = 1; i < var_shape.size(); ++i) {
|
||||
if (var_shape[i] != grad_shape[i]) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the shape of 'var' and 'grad' must be equal in dimension i=" << i
|
||||
<< ", but got 'var_shape[i]': " << var_shape[i] << " and 'grad_shape[i]': " << grad_shape[i];
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
}
|
||||
if (indices_shape[0] != grad_shape[0]) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_
|
||||
<< "', the size of 'grad' must be the same as the size of "
|
||||
"'indicies' ";
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
if (indices_shape.size() != 1) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'indices' must be a 1-D vector, but got "
|
||||
<< indices_shape.size() << "-D.";
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
// auto indices_size = indices_shape[0];
|
||||
auto indices_size = 1;
|
||||
for (size_t i = 0; i < indices_shape.size(); i++) {
|
||||
indices_size *= indices_shape[i];
|
||||
}
|
||||
if (grad_shape[0] != SizeToLong(indices_size)) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_
|
||||
<< "', the first dimension value of 'grad' must be equal to "
|
||||
"the first dimension value of 'indices', but got the first dimension value of 'grad': "
|
||||
<< grad_shape[0] << ", and the first dimension value of 'indices': " << indices_size;
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
input_elements_ = input_size_list_[0] / unit_size_;
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
bool SparseApplyCenteredRMSPropGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
auto var = reinterpret_cast<T *>(inputs[kVarIndex]->addr);
|
||||
auto mg = reinterpret_cast<T *>(inputs[kMgIndex]->addr);
|
||||
auto ms = reinterpret_cast<T *>(inputs[kMsIndex]->addr);
|
||||
auto mom = reinterpret_cast<T *>(inputs[kMomIndex]->addr);
|
||||
auto lr = reinterpret_cast<T *>(inputs[kLrIndex]->addr);
|
||||
auto rho = reinterpret_cast<T *>(inputs[kRhoIndex]->addr);
|
||||
auto momentum = reinterpret_cast<T *>(inputs[kMomentumIndex]->addr);
|
||||
auto epsilon = reinterpret_cast<T *>(inputs[kEpsilonIndex]->addr);
|
||||
auto grad = reinterpret_cast<T *>(inputs[kGradIndex]->addr);
|
||||
auto indices = reinterpret_cast<S *>(inputs[kIndicesIndex]->addr);
|
||||
auto var_out = reinterpret_cast<T *>(outputs[kVarIndex]->addr);
|
||||
|
||||
CalSparseApplyCenteredRMSProp(input_elements_, sizeof(S) / sizeof(int), use_locking_, lr, rho, epsilon, momentum,
|
||||
grad, indices, var, mg, ms, mom, var_out, reinterpret_cast<cudaStream_t>(cuda_stream_));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, SparseApplyCenteredRMSPropGpuKernelMod::SparseApplyCenteredRMSPropFunc>>
|
||||
SparseApplyCenteredRMSPropGpuKernelMod::func_list_ = {
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&SparseApplyCenteredRMSPropGpuKernelMod::LaunchKernel<float, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
&SparseApplyCenteredRMSPropGpuKernelMod::LaunchKernel<half, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
&SparseApplyCenteredRMSPropGpuKernelMod::LaunchKernel<double, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt8),
|
||||
&SparseApplyCenteredRMSPropGpuKernelMod::LaunchKernel<int8_t, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt16),
|
||||
&SparseApplyCenteredRMSPropGpuKernelMod::LaunchKernel<int16_t, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
&SparseApplyCenteredRMSPropGpuKernelMod::LaunchKernel<int32_t, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
&SparseApplyCenteredRMSPropGpuKernelMod::LaunchKernel<int64_t, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeUInt8),
|
||||
&SparseApplyCenteredRMSPropGpuKernelMod::LaunchKernel<uint8_t, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeUInt16),
|
||||
&SparseApplyCenteredRMSPropGpuKernelMod::LaunchKernel<uint16_t, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeUInt32),
|
||||
&SparseApplyCenteredRMSPropGpuKernelMod::LaunchKernel<uint32_t, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeUInt64),
|
||||
&SparseApplyCenteredRMSPropGpuKernelMod::LaunchKernel<uint64_t, int32_t>},
|
||||
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeFloat32),
   &SparseApplyCenteredRMSPropGpuKernelMod::LaunchKernel<float, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeFloat16),
   &SparseApplyCenteredRMSPropGpuKernelMod::LaunchKernel<half, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeFloat64),
   &SparseApplyCenteredRMSPropGpuKernelMod::LaunchKernel<double, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt8)
     .AddInputAttr(kNumberTypeInt8)
     .AddInputAttr(kNumberTypeInt8)
     .AddInputAttr(kNumberTypeInt8)
     .AddInputAttr(kNumberTypeInt8)
     .AddInputAttr(kNumberTypeInt8)
     .AddInputAttr(kNumberTypeInt8)
     .AddInputAttr(kNumberTypeInt8)
     .AddInputAttr(kNumberTypeInt8)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeInt8),
   &SparseApplyCenteredRMSPropGpuKernelMod::LaunchKernel<int8_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt16)
     .AddInputAttr(kNumberTypeInt16)
     .AddInputAttr(kNumberTypeInt16)
     .AddInputAttr(kNumberTypeInt16)
     .AddInputAttr(kNumberTypeInt16)
     .AddInputAttr(kNumberTypeInt16)
     .AddInputAttr(kNumberTypeInt16)
     .AddInputAttr(kNumberTypeInt16)
     .AddInputAttr(kNumberTypeInt16)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeInt16),
   &SparseApplyCenteredRMSPropGpuKernelMod::LaunchKernel<int16_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeInt32),
   &SparseApplyCenteredRMSPropGpuKernelMod::LaunchKernel<int32_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeInt64),
   &SparseApplyCenteredRMSPropGpuKernelMod::LaunchKernel<int64_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt8)
     .AddInputAttr(kNumberTypeUInt8)
     .AddInputAttr(kNumberTypeUInt8)
     .AddInputAttr(kNumberTypeUInt8)
     .AddInputAttr(kNumberTypeUInt8)
     .AddInputAttr(kNumberTypeUInt8)
     .AddInputAttr(kNumberTypeUInt8)
     .AddInputAttr(kNumberTypeUInt8)
     .AddInputAttr(kNumberTypeUInt8)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeUInt8),
   &SparseApplyCenteredRMSPropGpuKernelMod::LaunchKernel<uint8_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt16)
     .AddInputAttr(kNumberTypeUInt16)
     .AddInputAttr(kNumberTypeUInt16)
     .AddInputAttr(kNumberTypeUInt16)
     .AddInputAttr(kNumberTypeUInt16)
     .AddInputAttr(kNumberTypeUInt16)
     .AddInputAttr(kNumberTypeUInt16)
     .AddInputAttr(kNumberTypeUInt16)
     .AddInputAttr(kNumberTypeUInt16)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeUInt16),
   &SparseApplyCenteredRMSPropGpuKernelMod::LaunchKernel<uint16_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt32)
     .AddInputAttr(kNumberTypeUInt32)
     .AddInputAttr(kNumberTypeUInt32)
     .AddInputAttr(kNumberTypeUInt32)
     .AddInputAttr(kNumberTypeUInt32)
     .AddInputAttr(kNumberTypeUInt32)
     .AddInputAttr(kNumberTypeUInt32)
     .AddInputAttr(kNumberTypeUInt32)
     .AddInputAttr(kNumberTypeUInt32)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeUInt32),
   &SparseApplyCenteredRMSPropGpuKernelMod::LaunchKernel<uint32_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt64)
     .AddInputAttr(kNumberTypeUInt64)
     .AddInputAttr(kNumberTypeUInt64)
     .AddInputAttr(kNumberTypeUInt64)
     .AddInputAttr(kNumberTypeUInt64)
     .AddInputAttr(kNumberTypeUInt64)
     .AddInputAttr(kNumberTypeUInt64)
     .AddInputAttr(kNumberTypeUInt64)
     .AddInputAttr(kNumberTypeUInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeUInt64),
   &SparseApplyCenteredRMSPropGpuKernelMod::LaunchKernel<uint64_t, int64_t>},
};

std::vector<KernelAttr> SparseApplyCenteredRMSPropGpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
                       [](const std::pair<KernelAttr, SparseApplyCenteredRMSPropFunc> &pair) { return pair.first; });
  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, SparseApplyCenteredRMSProp, SparseApplyCenteredRMSPropGpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
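For reference, the update this dispatch table launches can be sketched in NumPy. This is a hedged sketch, not the CUDA implementation: it assumes the operator follows the centered RMSProp update of TensorFlow's SparseApplyCenteredRMSProp, and the function name and looping are illustrative. On the inputs used by the GPU tests at the end of this diff it reproduces the expected output [[0.5968, 0.3959], [0.0989, 0.4978]].

```python
import numpy as np

def sparse_apply_centered_rms_prop(var, mg, ms, mom, lr, rho, momentum, epsilon, grad, indices):
    # Update only the rows named in `indices`; all other rows are untouched.
    for row, g in zip(indices, grad):
        ms[row] = rho * ms[row] + (1.0 - rho) * g * g   # running second moment
        mg[row] = rho * mg[row] + (1.0 - rho) * g       # running first moment (the "centered" part)
        denom = np.sqrt(ms[row] - mg[row] * mg[row] + epsilon)
        mom[row] = momentum * mom[row] + lr * g / denom
        var[row] -= mom[row]
    return var
```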
@ -0,0 +1,79 @@
/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_SPARSE_APPLY_CENTERED_RMS_PROP_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_SPARSE_APPLY_CENTERED_RMS_PROP_GPU_KERNEL_H_

#include <vector>
#include <algorithm>
#include <iostream>
#include <utility>
#include <memory>
#include <functional>
#include <map>
#include <string>
#include <cstdio>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_apply_centered_rms_prop_impl.cuh"

namespace mindspore {
namespace kernel {
class SparseApplyCenteredRMSPropGpuKernelMod : public NativeGpuKernelMod {
 public:
  SparseApplyCenteredRMSPropGpuKernelMod() = default;
  ~SparseApplyCenteredRMSPropGpuKernelMod() override = default;

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
    MS_EXCEPTION_IF_NULL(cuda_stream);
    if (is_null_input_) {
      return true;
    }
    cuda_stream_ = cuda_stream;
    return kernel_func_(this, inputs, workspace, outputs);
  }

 protected:
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  template <typename T, typename S>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
                    const std::vector<AddressPtr> &outputs);
  using SparseApplyCenteredRMSPropFunc =
    std::function<bool(SparseApplyCenteredRMSPropGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
                       const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
  static std::vector<std::pair<KernelAttr, SparseApplyCenteredRMSPropFunc>> func_list_;
  SparseApplyCenteredRMSPropFunc kernel_func_;

  void *cuda_stream_{nullptr};
  bool is_null_input_{false};
  // Default-initialize the state members so they are well defined before Init()/Resize() run.
  bool use_locking_{false};
  int unit_size_{0};
  size_t input_elements_{0};
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_SPARSE_APPLY_CENTERED_RMS_PROP_GPU_KERNEL_H_
@ -0,0 +1,99 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/other/hsv_to_rgb_gpu_kernel.h"
#include <utility>
namespace mindspore {
namespace kernel {
namespace {
template <typename T, typename S>
std::unique_ptr<cukernel::GpuKernelHelperBase> CreateHsvToRgbKernelPtr(const std::string &kernel_name,
                                                                       const uint32_t &device_id) {
  return std::make_unique<cukernel::HsvToRgbHelperGpuKernel<T, S>>(kernel_name, device_id);
}
using HsvToRgbPtrCreatorFunc =
  std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;

const std::vector<std::pair<KernelAttr, HsvToRgbPtrCreatorFunc>> kernel_attr = {
  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   CreateHsvToRgbKernelPtr<half, half>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   CreateHsvToRgbKernelPtr<float, float>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
   CreateHsvToRgbKernelPtr<double, double>}};
}  // namespace

bool HsvtorgbGpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                  const std::vector<AddressPtr> &outputs, void *stream_ptr) {
  std::vector<void *> input_ptrs = ConvertPtrs(inputs);
  std::vector<void *> work_ptrs = ConvertPtrs(workspace);
  std::vector<void *> output_ptrs = ConvertPtrs(outputs);
  if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) {
    return false;
  }
  return true;
}

bool HsvtorgbGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                const std::vector<KernelTensorPtr> &outputs) {
  auto kernel_ptr = std::make_shared<ops::HSVToRGB>(base_operator->GetPrim());
  kernel_name_ = kernel_ptr->name();

  auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport());
  if (!is_match) {
    return false;
  }
  helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_));

  return true;
}

int HsvtorgbGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                 const std::vector<KernelTensorPtr> &outputs,
                                 const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
  for (const auto &input : inputs) {
    // A -1 anywhere in an input shape means the shape is still dynamic; defer until the real shape is known.
    auto input_shape = input->GetShapeVector();
    if (!IsValidShape(input_shape)) {
      return KRET_UNKNOWN_SHAPE;
    }
  }
  std::vector<std::vector<int64_t>> input_shapes;
  std::vector<std::vector<int64_t>> output_shapes;
  std::vector<int64_t> inp_shape = inputs[0]->GetShapeVector();
  std::vector<int64_t> out_shape = outputs[0]->GetShapeVector();
  input_shapes.emplace_back(inp_shape);
  output_shapes.emplace_back(out_shape);
  if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) {
    return KRET_RESIZE_FAILED;
  }
  input_size_list_ = helper_ptr_->GetInputSizeList();
  output_size_list_ = helper_ptr_->GetOutputSizeList();
  workspace_size_list_ = helper_ptr_->GetWorkSizeList();
  return KRET_OK;
}

std::vector<KernelAttr> HsvtorgbGpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  (void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list),
                       [](const std::pair<KernelAttr, HsvToRgbPtrCreatorFunc> &item) { return item.first; });
  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, HSVToRGB, HsvtorgbGpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
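As a semantics note for the helper this file drives: the per-pixel conversion is the standard sector-based HSV-to-RGB formula. A hedged plain-Python sketch (the function name is illustrative, not from this change) is below; on the test input h = s = v = 0.5 used later in this diff it yields (0.25, 0.5, 0.5), matching the expected tensor.

```python
def hsv_to_rgb_pixel(h, s, v):
    # h, s, v in [0, 1]; returns (r, g, b) in [0, 1].
    i = int(h * 6.0) % 6           # hue sector 0..5
    f = h * 6.0 - int(h * 6.0)     # fractional position inside the sector
    p = v * (1.0 - s)
    q = v * (1.0 - f * s)
    t = v * (1.0 - (1.0 - f) * s)
    return [(v, t, p), (q, v, p), (p, v, t), (p, q, v), (t, p, v), (v, p, q)][i]

print(hsv_to_rgb_pixel(0.5, 0.5, 0.5))  # (0.25, 0.5, 0.5)
```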
@ -0,0 +1,55 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_HSVTORGB_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_HSVTORGB_GPU_KERNEL_H_

#include <vector>
#include <string>
#include <memory>
#include <algorithm>
#include <functional>
#include <map>

#include "mindspore/core/ops/hsv_to_rgb.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/hsvtorgb_helper.h"
namespace mindspore {
namespace kernel {
class HsvtorgbGpuKernelMod : public NativeGpuKernelMod {
 public:
  HsvtorgbGpuKernelMod() = default;
  ~HsvtorgbGpuKernelMod() override = default;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;
  int Resize(
    const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
    const std::vector<KernelTensorPtr> &outputs,
    const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  std::unique_ptr<cukernel::GpuKernelHelperBase> helper_ptr_{nullptr};
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_HSVTORGB_GPU_KERNEL_H_
@ -0,0 +1,99 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/other/rgb_to_hsv_gpu_kernel.h"
#include <utility>
namespace mindspore {
namespace kernel {
namespace {
template <typename T, typename S>
std::unique_ptr<cukernel::GpuKernelHelperBase> CreateRgbToHsvKernelPtr(const std::string &kernel_name,
                                                                       const uint32_t &device_id) {
  return std::make_unique<cukernel::RgbToHsvHelperGpuKernel<T, S>>(kernel_name, device_id);
}
using RgbToHsvPtrCreatorFunc =
  std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;

const std::vector<std::pair<KernelAttr, RgbToHsvPtrCreatorFunc>> kernel_attr = {
  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   CreateRgbToHsvKernelPtr<half, half>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   CreateRgbToHsvKernelPtr<float, float>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
   CreateRgbToHsvKernelPtr<double, double>}};
}  // namespace

bool RgbtohsvGpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                  const std::vector<AddressPtr> &outputs, void *stream_ptr) {
  std::vector<void *> input_ptrs = ConvertPtrs(inputs);
  std::vector<void *> work_ptrs = ConvertPtrs(workspace);
  std::vector<void *> output_ptrs = ConvertPtrs(outputs);
  if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) {
    return false;
  }
  return true;
}

bool RgbtohsvGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                const std::vector<KernelTensorPtr> &outputs) {
  auto kernel_ptr = std::make_shared<ops::RGBToHSV>(base_operator->GetPrim());
  kernel_name_ = kernel_ptr->name();

  auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport());
  if (!is_match) {
    return false;
  }
  helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_));

  return true;
}

int RgbtohsvGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                 const std::vector<KernelTensorPtr> &outputs,
                                 const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
  for (const auto &input : inputs) {
    // A -1 anywhere in an input shape means the shape is still dynamic; defer until the real shape is known.
    auto input_shape = input->GetShapeVector();
    if (!IsValidShape(input_shape)) {
      return KRET_UNKNOWN_SHAPE;
    }
  }
  std::vector<std::vector<int64_t>> input_shapes;
  std::vector<std::vector<int64_t>> output_shapes;
  std::vector<int64_t> inp_shape = inputs[0]->GetShapeVector();
  std::vector<int64_t> out_shape = outputs[0]->GetShapeVector();
  input_shapes.emplace_back(inp_shape);
  output_shapes.emplace_back(out_shape);
  if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) {
    return KRET_RESIZE_FAILED;
  }
  input_size_list_ = helper_ptr_->GetInputSizeList();
  output_size_list_ = helper_ptr_->GetOutputSizeList();
  workspace_size_list_ = helper_ptr_->GetWorkSizeList();
  return KRET_OK;
}

std::vector<KernelAttr> RgbtohsvGpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  (void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list),
                       [](const std::pair<KernelAttr, RgbToHsvPtrCreatorFunc> &item) { return item.first; });
  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, RGBToHSV, RgbtohsvGpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
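The inverse direction follows the usual max/min formulation. Again a hedged plain-Python sketch with an illustrative name, not the CUDA helper; on the test input (0.25, 0.5, 0.5) it returns (0.5, 0.5, 0.5), matching the expected tensor in the new tests.

```python
def rgb_to_hsv_pixel(r, g, b):
    # r, g, b in [0, 1]; returns (h, s, v) in [0, 1].
    cmax, cmin = max(r, g, b), min(r, g, b)
    v = cmax
    d = cmax - cmin
    s = 0.0 if cmax == 0.0 else d / cmax
    if d == 0.0:
        h = 0.0                      # grey pixel: hue is conventionally 0
    elif cmax == r:
        h = ((g - b) / d) % 6.0
    elif cmax == g:
        h = (b - r) / d + 2.0
    else:
        h = (r - g) / d + 4.0
    return h / 6.0, s, v

print(rgb_to_hsv_pixel(0.25, 0.5, 0.5))  # (0.5, 0.5, 0.5)
```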
@ -0,0 +1,54 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_RGBTOHSV_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_RGBTOHSV_GPU_KERNEL_H_

#include <vector>
#include <string>
#include <memory>
#include <algorithm>
#include <functional>
#include <map>
#include "mindspore/core/ops/rgb_to_hsv.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/rgbtohsv_helper.h"
namespace mindspore {
namespace kernel {
class RgbtohsvGpuKernelMod : public NativeGpuKernelMod {
 public:
  RgbtohsvGpuKernelMod() = default;
  ~RgbtohsvGpuKernelMod() override = default;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;
  int Resize(
    const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
    const std::vector<KernelTensorPtr> &outputs,
    const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  std::unique_ptr<cukernel::GpuKernelHelperBase> helper_ptr_{nullptr};
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_RGBTOHSV_GPU_KERNEL_H_
@ -24,14 +24,20 @@ namespace ops {
 namespace {
 abstract::ShapePtr RGBToHSVInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
-
-  const int64_t input_dims = SizeToLong(input_shape.size());
-  const int64_t input_last_dims = input_shape.cend()[-1];
-  const int64_t numberofRGB_3 = 3;
-  (void)CheckAndConvertUtils::CheckInteger("last dimension of input 'images'", input_last_dims, kEqual, numberofRGB_3,
-                                           kNameRGBToHSV);
-  if (input_dims < 1) {
-    MS_LOG(EXCEPTION) << "For " << primitive->name() << ", the dimension of input 'images' must be 1-D or higher rank.";
+  auto input_shape_ptr = input_args[0]->BuildShape();
+  if (IsDynamicRank(input_shape)) {
+    return std::make_shared<abstract::Shape>(std::vector<int64_t>{-2});
+  }
+  if (!input_shape_ptr->IsDynamic()) {
+    const int64_t input_dims = SizeToLong(input_shape.size());
+    const int64_t input_last_dims = input_shape.cend()[-1];
+    const int64_t numberofRGB_3 = 3;
+    (void)CheckAndConvertUtils::CheckInteger("last dimension of input 'images'", input_last_dims, kEqual,
+                                             numberofRGB_3, kNameRGBToHSV);
+    if (input_dims < 1) {
+      MS_LOG(EXCEPTION) << "For " << primitive->name()
+                        << ", the dimension of input 'images' must be 1-D or higher rank.";
+    }
   }
   return std::make_shared<abstract::Shape>(input_shape);
 }
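A note on the sentinels used above, stated as an assumption about MindSpore's shape encoding: -1 marks a single unknown dimension, while the one-element shape {-2} marks unknown rank, which is why the infer function returns it under IsDynamicRank and defers the rank and last-dimension checks to runtime.

```python
# Shape sentinel conventions assumed by the infer functions in this change:
shape_static = [1, 1, 1, 3]      # fully known
shape_dyn_dim = [1, -1, -1, 3]   # rank known (4-D), some dimensions unknown
shape_dyn_rank = [-2]            # even the rank is unknown; validation is deferred
print(shape_static, shape_dyn_dim, shape_dyn_rank)
```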
@ -21,6 +21,8 @@
#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "mindapi/src/helper.h"

namespace mindspore {
@ -42,16 +44,15 @@ abstract::ShapePtr SparseApplyCenteredRMSPropInferShape(const PrimitivePtr &prim
   auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[9]->BuildShape())[kShape];
 
   const int64_t scalar_shape = 0;
-  std::vector<ShapeVector> scalar_shapes = {lr_shape, rho_shape, momentum_shape, epsilon_shape};
-  auto is_dynamic_scalar = std::any_of(scalar_shapes.begin(), scalar_shapes.end(), IsDynamic);
-  if (!is_dynamic_scalar) {
-    (void)CheckAndConvertUtils::CheckInteger("lr_shape size", lr_shape.size(), kEqual, scalar_shape, prim_name);
-    (void)CheckAndConvertUtils::CheckInteger("rho_shape size", rho_shape.size(), kEqual, scalar_shape, prim_name);
-    (void)CheckAndConvertUtils::CheckInteger("momentum_shape size", momentum_shape.size(), kEqual, scalar_shape,
-                                             prim_name);
-    (void)CheckAndConvertUtils::CheckInteger("epsilon_shape size", epsilon_shape.size(), kEqual, scalar_shape,
-                                             prim_name);
-  }
+  if (IsDynamicRank(var_shape) || IsDynamicRank(mg_shape) || IsDynamicRank(ms_shape) || IsDynamicRank(mom_shape) ||
+      IsDynamicRank(grad_shape)) {
+    return std::make_shared<abstract::Shape>(std::vector<int64_t>{-2});
+  }
+  (void)CheckAndConvertUtils::CheckInteger("lr_shape size", lr_shape.size(), kEqual, scalar_shape, prim_name);
+  (void)CheckAndConvertUtils::CheckInteger("rho_shape size", rho_shape.size(), kEqual, scalar_shape, prim_name);
+  (void)CheckAndConvertUtils::CheckInteger("momentum_shape size", momentum_shape.size(), kEqual, scalar_shape,
+                                           prim_name);
+  (void)CheckAndConvertUtils::CheckInteger("epsilon_shape size", epsilon_shape.size(), kEqual, scalar_shape, prim_name);
 
   std::vector<ShapeVector> tensor_shapes = {var_shape, mg_shape, ms_shape, mom_shape};
   auto is_dynamic_tensor = std::any_of(tensor_shapes.begin(), tensor_shapes.end(), IsDynamic);
@ -64,7 +65,6 @@ abstract::ShapePtr SparseApplyCenteredRMSPropInferShape(const PrimitivePtr &prim
      CheckAndConvertUtils::Check(elem.first, elem.second, kEqual, var_shape, prim_name);
    }
  }

  // Var dimension must be equal or greater than 1.
  (void)CheckAndConvertUtils::CheckInteger("var dimension", SizeToLong(var_shape.size()), kGreaterEqual, 1, prim_name);
  // Indices must be rank 1.
@ -106,7 +106,6 @@ TypePtr SparseApplyCenteredRMSPropInferType(const PrimitivePtr &primitive,
  auto epsilon = input_args[7]->BuildType();
  auto grad = input_args[8]->BuildType();
  auto indices = input_args[9]->BuildType();

  std::map<std::string, TypePtr> args;
  (void)args.emplace("var", var);
  (void)args.emplace("ms", mg);
@ -125,6 +124,13 @@ TypePtr SparseApplyCenteredRMSPropInferType(const PrimitivePtr &primitive,
 }
 }  // namespace
 
+void SparseApplyCenteredRMSProp::Init(bool use_locking) { set_use_locking(use_locking); }
+
+void SparseApplyCenteredRMSProp::set_use_locking(bool use_locking) {
+  (void)AddAttr(kUseLocking, api::MakeValue(use_locking));
+}
+bool SparseApplyCenteredRMSProp::get_use_locking() { return GetValue<bool>(GetAttr(kUseLocking)); }
+
 AbstractBasePtr SparseApplyCenteredRMSPropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                                 const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
@ -134,7 +140,9 @@ AbstractBasePtr SparseApplyCenteredRMSPropInfer(const abstract::AnalysisEnginePt
   auto infer_shape = SparseApplyCenteredRMSPropInferShape(primitive, input_args);
   return abstract::MakeAbstract(infer_shape, infer_type);
 }
+
+MIND_API_OPERATOR_IMPL(SparseApplyCenteredRMSProp, BaseOperator);
 
 REGISTER_PRIMITIVE_EVAL_IMPL(SparseApplyCenteredRMSProp, prim::kPrimSparseApplyCenteredRMSProp,
                              SparseApplyCenteredRMSPropInfer, nullptr, true);
 }  // namespace ops
@ -35,6 +35,12 @@ class MIND_API SparseApplyCenteredRMSProp : public BaseOperator {
   SparseApplyCenteredRMSProp() : BaseOperator(kNameSparseApplyCenteredRMSProp) {
     InitIOName({"var", "mg", "ms", "mom", "lr", "rho", "momentum", "epsilon", "grad", "indices"}, {"var"});
   }
+
+  void Init(bool use_locking = false);
+
+  void set_use_locking(bool use_locking);
+
+  bool get_use_locking();
 };
 
 abstract::AbstractBasePtr SparseApplyCenteredRMSPropInfer(const abstract::AnalysisEnginePtr &,
@ -487,12 +487,14 @@ class NonMaxSuppressionWithOverlaps(Primitive):
 class HSVToRGB(Primitive):
     """
-    Convert one or more images from HSV to RGB. The format of the image(s) should be NHWC.
+    Convert one or more images from HSV to RGB.
+    Outputs a tensor of the same shape as the images tensor, containing the RGB value of the pixels.
+    The output is only well defined if the values in images are in [0, 1].
 
     Inputs:
-        - **x** (Tensor) - The input image must be a 4-D tensor of shape [batch, image_height, image_width, channel].
-          Number of channel must be 3.
-          Types allowed: float16, float32, float64.
+        **x** (Tensor) - The input image must be a 4-D tensor of shape [batch, image_height, image_width, channel].
+        Number of channels must be 3.
+        Types allowed: float16, float32, float64.
     Outputs:
         A 4-D tensor of shape [batch, image_height, image_width, channel] with the same type as the input.
@ -503,7 +505,7 @@ class HSVToRGB(Primitive):
     ValueError: If the last dimension of `x` is not equal to 3.
 
     Supported Platforms:
-        ``CPU``
+        ``GPU`` ``CPU``
 
     Examples:
         >>> image = np.array([0.5, 0.5, 0.5]).astype(np.float32).reshape([1, 1, 1, 3])
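The Examples block is cut off in this view; a minimal end-to-end sketch of the documented usage, mirroring the GPU test added later in this change, would look like the following (the variable names are illustrative):

```python
import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

# h = s = v = 0.5 for a single pixel, laid out as [batch, height, width, channel].
image = Tensor(np.array([0.5, 0.5, 0.5]).astype(np.float32).reshape([1, 1, 1, 3]))
hsv_to_rgb = P.HSVToRGB()
print(hsv_to_rgb(image))  # expected: [[[[0.25 0.5  0.5 ]]]]
```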
@ -611,7 +613,7 @@ class RGBToHSV(Primitive):
     ValueError: If the last value of shape of `images` is not 3.
 
     Supported Platforms:
-        ``Ascend`` ``CPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> images = np.array([0.25, 0.5, 0.5]).astype(np.float32).reshape([1, 1, 1, 3])
@ -8778,7 +8778,7 @@ class SparseApplyCenteredRMSProp(Primitive):
     ValueError: If shape of `grad` is not same as shape of `var` except first dimension.
 
     Supported Platforms:
-        ``Ascend`` ``CPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> import numpy as np
@ -0,0 +1,69 @@
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.hsvtorgb = P.HSVToRGB()

    def construct(self, x):
        return self.hsvtorgb(x)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_net_float16():
    """
    Feature: HSVToRGB GPU operator.
    Description: basic test with a float16 input pixel.
    Expectation: output matches the expected RGB values within tolerance.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    x = np.array([0.5, 0.5, 0.5]).astype(np.float16).reshape([1, 1, 1, 3])
    net = Net()
    output = net(Tensor(x))
    expected = np.array([0.25, 0.5, 0.5]).astype(np.float16).reshape([1, 1, 1, 3])
    assert np.allclose(output.asnumpy(), expected, 1e-3, 1e-3)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_net_float32():
    """
    Feature: HSVToRGB GPU operator.
    Description: basic test with a float32 input pixel.
    Expectation: output matches the expected RGB values within tolerance.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    x = np.array([0.5, 0.5, 0.5]).astype(np.float32).reshape([1, 1, 1, 3])
    net = Net()
    output = net(Tensor(x))
    expected = np.array([0.25, 0.5, 0.5]).astype(np.float32).reshape([1, 1, 1, 3])
    assert np.allclose(output.asnumpy(), expected, 1e-4, 1e-4)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_net_float64():
    """
    Feature: HSVToRGB GPU operator.
    Description: basic test with a float64 input pixel.
    Expectation: output matches the expected RGB values within tolerance.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    x = np.array([0.5, 0.5, 0.5]).astype(np.float64).reshape([1, 1, 1, 3])
    net = Net()
    output = net(Tensor(x))
    expected = np.array([0.25, 0.5, 0.5]).astype(np.float64).reshape([1, 1, 1, 3])
    assert np.allclose(output.asnumpy(), expected, 1e-5, 1e-5)
@ -0,0 +1,69 @@
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
import mindspore.ops.operations.image_ops as P


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.rgbtohsv = P.RGBToHSV()

    def construct(self, x):
        return self.rgbtohsv(x)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_net_float16():
    """
    Feature: RGBToHSV GPU operator.
    Description: basic test with a float16 input pixel.
    Expectation: output matches the expected HSV values within tolerance.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    x = np.array([0.25, 0.5, 0.5]).astype(np.float16).reshape([1, 1, 1, 3])
    net = Net()
    output = net(Tensor(x))
    expected = np.array([0.5, 0.5, 0.5]).astype(np.float16).reshape([1, 1, 1, 3])
    assert np.allclose(output.asnumpy(), expected, 1e-3, 1e-3)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_net_float32():
    """
    Feature: RGBToHSV GPU operator.
    Description: basic test with a float32 input pixel.
    Expectation: output matches the expected HSV values within tolerance.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    x = np.array([0.25, 0.5, 0.5]).astype(np.float32).reshape([1, 1, 1, 3])
    net = Net()
    output = net(Tensor(x))
    expected = np.array([0.5, 0.5, 0.5]).astype(np.float32).reshape([1, 1, 1, 3])
    assert np.allclose(output.asnumpy(), expected, 1e-4, 1e-4)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_net_float64():
    """
    Feature: RGBToHSV GPU operator.
    Description: basic test with a float64 input pixel.
    Expectation: output matches the expected HSV values within tolerance.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    x = np.array([0.25, 0.5, 0.5]).astype(np.float64).reshape([1, 1, 1, 3])
    net = Net()
    output = net(Tensor(x))
    expected = np.array([0.5, 0.5, 0.5]).astype(np.float64).reshape([1, 1, 1, 3])
    assert np.allclose(output.asnumpy(), expected, 1e-5, 1e-5)
@ -0,0 +1,94 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest
import mindspore.context as context
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.parameter import Parameter
import mindspore.ops.operations.nn_ops as P


class SparseApplyCenteredRMSPropNet(nn.Cell):
    def __init__(self, use_locking=False):
        super(SparseApplyCenteredRMSPropNet, self).__init__()
        # Forward the flag instead of hard-coding False so the constructor argument takes effect.
        self.sparse_apply_centered_rms_prop = P.SparseApplyCenteredRMSProp(use_locking=use_locking)

    def construct(self, var, mg, ms, mom, lr, rho, momentum, epsilon, grad, indices):
        out = self.sparse_apply_centered_rms_prop(var, mg, ms, mom, lr, rho, momentum, epsilon, grad, indices)
        return out


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparse_apply_centered_rms_prop_graph_1():
    """
    Feature: SparseApplyCenteredRMSProp GPU operator.
    Description: inputs are [2, 2] tensors for the mutable states, scalar tensors for the
        hyperparameters, and a [2] tensor for indices; runs in graph mode.
    Expectation: the updated var matches the TensorFlow reference output.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    var = Parameter(Tensor(np.array([[0.6, 0.4], [0.1, 0.5]]).astype(np.float32)), name="var")
    mg = Parameter(Tensor(np.array([[0.1, 0.3], [0.1, 0.5]]).astype(np.float32)), name="mg")
    ms = Parameter(Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32)), name="ms")
    mom = Parameter(Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32)), name="mom")
    lr = Tensor(0.001, mstype.float32)
    rho = Tensor(1e-10, mstype.float32)
    momentum = Tensor(0.001, mstype.float32)
    epsilon = Tensor(0.01, mstype.float32)
    grad = Parameter(Tensor(np.array([[0.3, 0.4], [0.1, 0.2]]).astype(np.float32)))
    indices = Tensor(np.array([0, 1]).astype(np.int32))
    net = SparseApplyCenteredRMSPropNet(use_locking=False)
    output = net(var, mg, ms, mom, lr, rho, momentum, epsilon, grad, indices)
    expected = np.array([[0.5968, 0.3959], [0.0989, 0.4978]]).astype(np.float32)
    assert np.allclose(output.asnumpy(), expected, rtol=1e-3)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparse_apply_centered_rms_prop_graph_2():
    """
    Feature: SparseApplyCenteredRMSProp GPU operator.
    Description: same inputs as the previous case, run without reconfiguring the context.
    Expectation: the updated var matches the TensorFlow reference output.
    """
    var = Parameter(Tensor(np.array([[0.6, 0.4], [0.1, 0.5]]).astype(np.float32)), name="var")
    mg = Parameter(Tensor(np.array([[0.1, 0.3], [0.1, 0.5]]).astype(np.float32)), name="mg")
    ms = Parameter(Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32)), name="ms")
    mom = Parameter(Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32)), name="mom")
    lr = Tensor(0.001, mstype.float32)
    rho = Tensor(1e-10, mstype.float32)
    momentum = Tensor(0.001, mstype.float32)
    epsilon = Tensor(0.01, mstype.float32)
    grad = Parameter(Tensor(np.array([[0.3, 0.4], [0.1, 0.2]]).astype(np.float32)))
    indices = Tensor(np.array([0, 1]).astype(np.int32))
    net = SparseApplyCenteredRMSPropNet(use_locking=False)
    output = net(var, mg, ms, mom, lr, rho, momentum, epsilon, grad, indices)
    expected = np.array([[0.5968, 0.3959], [0.0989, 0.4978]]).astype(np.float32)
    assert np.allclose(output.asnumpy(), expected, rtol=1e-3)