!34164 [feat] [assistant] [ops] [I4ZZTZ] New GPU operator implementation, include LogSpace

Merge pull request !34164 from yvlee/logspace
This commit is contained in:
i-robot 2022-06-27 12:39:22 +00:00 committed by Gitee
commit b7d17ab460
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
9 changed files with 369 additions and 18 deletions

View File

@ -0,0 +1,110 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/arrays/logspace_gpu_kernel.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/logspace_impl.cuh"
namespace mindspore {
namespace kernel {
bool LogSpaceGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr_ = std::dynamic_pointer_cast<ops::LogSpace>(base_operator);
kernel_name_ = kernel_ptr_->name();
// inputs and outputs should not be empty
if (inputs.empty() || outputs.empty()) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid.";
return false;
}
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the kernel type should be in [ float16, float32, float64 ], "
<< "but got: " << kernel_attr << ".";
return false;
}
kernel_func_ = func_list_[index].second;
unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
steps_ = kernel_ptr_->get_steps();
base_ = kernel_ptr_->get_base();
if (steps_ < 0) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the value of 'steps' should be larger than 0, "
<< "but got " << steps_;
return false;
}
{
size_t input_size = 2 * unit_size_;
input_size_list_.emplace_back(input_size);
size_t output_size = steps_ * unit_size_;
output_size_list_.emplace_back(output_size);
}
return true;
}
int LogSpaceGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
for (const auto &input : inputs) {
// If any input shape contains -1, means input shape is dynamic, so just return do nothing.
auto input_shape = input->GetShapeVector();
if (!IsValidShape(input_shape)) {
return KRET_UNKNOWN_SHAPE;
}
}
ResetResource();
size_t input_size = 2 * unit_size_;
input_size_list_.emplace_back(input_size);
output_size_list_.emplace_back(steps_ * unit_size_);
return KRET_OK;
}
// Empties both size lists so that they can be rebuilt from scratch.
void LogSpaceGpuKernelMod::ResetResource() noexcept {
  output_size_list_.clear();
  input_size_list_.clear();
}
// Typed launch routine selected through func_list_.
// Copies the two scalar inputs (start, end) from device to host, derives the per-step exponent
// increment, and enqueues CalLogSpace on the stream recorded by Launch().
// The workspace argument is unused. Returns true; CUDA failures are reported through
// CHECK_CUDA_RET_WITH_ERROR_NOTRACE.
template <typename T>
bool LogSpaceGpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                        const std::vector<kernel::AddressPtr> &,
                                        const std::vector<kernel::AddressPtr> &outputs) {
  // Nothing to compute for an empty output.
  if (steps_ == 0) return true;
  auto stream = reinterpret_cast<cudaStream_t>(cuda_stream_);
  auto start = GetDeviceAddress<T>(inputs, kIndex0);
  auto end = GetDeviceAddress<T>(inputs, kIndex1);
  T *output = GetDeviceAddress<T>(outputs, kIndex0);
  T host_start, host_end;
  CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(&host_start, start, sizeof(T), cudaMemcpyDeviceToHost, stream),
                                    "For LogSpace, cudaMemcpy start failed");
  CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(&host_end, end, sizeof(T), cudaMemcpyDeviceToHost, stream),
                                    "For LogSpace, cudaMemcpy end failed");
  // The copies above are asynchronous: wait for the stream before reading the host scalars.
  CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaStreamSynchronize(stream), "For LogSpace, cudaStreamSynchronize failed");
  // With a single step the divisor is 1 (avoids dividing by steps - 1 == 0); otherwise the range is
  // split into steps - 1 equal intervals so that the last exponent equals 'end'.
  T host_add = ((host_end - host_start) / (steps_ == 1 ? steps_ : steps_ - 1));
  CalLogSpace(host_start, host_add, steps_, base_, output, device_id_, stream);
  return true;
}
// Dispatch table: each entry pairs a supported KernelAttr (two inputs and one output sharing the
// same floating-point type) with the LaunchKernel<T> instantiation for that type.
// Entries, in order: float16 (half), float32 (float), float64 (double).
// Init() selects an entry by the index MatchKernelAttr returns for the list built by GetOpSupport(),
// so the order here must stay in sync with that list.
std::vector<std::pair<KernelAttr, LogSpaceGpuKernelMod::LogSpaceFunc>> LogSpaceGpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&LogSpaceGpuKernelMod::LaunchKernel<half>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&LogSpaceGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&LogSpaceGpuKernelMod::LaunchKernel<double>}};
std::vector<KernelAttr> LogSpaceGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, LogSpaceFunc> &item) { return item.first; });
return support_list;
}
// Registers LogSpaceGpuKernelMod under the name LogSpace
// (presumably in the NativeGpuKernelMod factory - confirm against the macro definition).
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, LogSpace, LogSpaceGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,74 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_LOGSPACE_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_LOGSPACE_GPU_KERNEL_H_
#include <vector>
#include <string>
#include <memory>
#include <utility>
#include <algorithm>
#include <functional>
#include <map>
#include "mindspore/core/ops/log_space.h"
#include "abstract/utils.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
// GPU kernel mod for the LogSpace operator.
// Init() reads the 'steps' and 'base' attributes and picks a typed launch function from func_list_;
// Launch() records the CUDA stream and forwards to that function.
class LogSpaceGpuKernelMod : public NativeGpuKernelMod {
 public:
  LogSpaceGpuKernelMod() { ResetResource(); }
  ~LogSpaceGpuKernelMod() override = default;

  // Stores the stream for the typed launch routine and dispatches to the function chosen in Init().
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
    cuda_stream_ = cuda_stream;
    return kernel_func_(this, inputs, workspace, outputs);
  }

  // Validates the operator and dtypes, reads the attributes and fills the size lists.
  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  // Rebuilds the size lists; returns KRET_UNKNOWN_SHAPE while an input shape is still dynamic.
  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

  // Lists the KernelAttr combinations present in func_list_.
  std::vector<KernelAttr> GetOpSupport() override;

  // Clears input_size_list_ and output_size_list_.
  void ResetResource() noexcept;

 private:
  // Typed implementation behind Launch(); T is the element type of start, end and the output.
  template <typename T>
  bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
                    const std::vector<kernel::AddressPtr> &outputs);

  using LogSpaceFunc =
    std::function<bool(LogSpaceGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
                       const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;

  size_t unit_size_{1};          // size in bytes of one element of the matched input dtype
  int64_t steps_{10};            // value of the 'steps' attribute (number of output elements)
  size_t base_{10};              // value of the 'base' attribute
  void *cuda_stream_{nullptr};   // stream handed to Launch(), used by LaunchKernel
  LogSpaceFunc kernel_func_{};   // typed launch function selected in Init()
  static std::vector<std::pair<KernelAttr, LogSpaceFunc>> func_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_LOGSPACE_GPU_KERNEL_H_

View File

@ -0,0 +1,60 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/logspace_impl.cuh"
#include "include/cuda_runtime.h"
#include "include/cuda_fp16.h"
// Writes output[i] = pow(base, start + add * i) for every i in [0, steps).
// Uses a grid-stride loop, so any 1-D launch configuration covers the whole range.
// output must hold at least 'steps' elements; a non-positive 'steps' writes nothing.
template <typename T>
__global__ void LogSpaceKernel(const T start, const T add,
                               const int64_t steps, const size_t base, T *output) {
  // Clamp before converting: a negative int64_t would otherwise turn into a huge unsigned bound.
  const size_t total = steps > 0 ? static_cast<size_t>(steps) : 0;
  // Widen before multiplying so that index and stride are computed in 64-bit.
  const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;
  for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; i < total; i += stride) {
    output[i] = pow(base, start + (add * i));
  }
}
// half specialization: converts start and add to float once, evaluates
// pow(base, start + add * i) and stores the result converted back to half.
// Uses a grid-stride loop; a non-positive 'steps' writes nothing.
template <>
__global__ void LogSpaceKernel(const half start, const half add,
                               const int64_t steps, const size_t base, half *output) {
  const float start_float = __half2float(start);
  const float add_value = __half2float(add);
  // Clamp before converting: a negative int64_t would otherwise turn into a huge unsigned bound.
  const size_t total = steps > 0 ? static_cast<size_t>(steps) : 0;
  // Widen before multiplying so that index and stride are computed in 64-bit.
  const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;
  for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; i < total; i += stride) {
    output[i] = __float2half(pow(base, start_float + (add_value * i)));
  }
}
// Host-side launcher for LogSpaceKernel.
//   start     - exponent of the first output element
//   add       - exponent increment between consecutive elements
//   steps     - number of elements to write
//   base      - base of the power
//   output    - device buffer with room for 'steps' elements
//   device_id - forwarded to CUDA_BLOCKS / CUDA_THREADS to size the launch
//   cuda_stream - stream the kernel is enqueued on
// Returns without launching when there is nothing to write.
template <typename T>
void CalLogSpace(const T start, const T add, const int64_t steps, const size_t base, T *output,
                 const uint32_t &device_id, cudaStream_t cuda_stream) {
  // Skip the launch for an empty range or a missing buffer instead of issuing a degenerate launch.
  if (steps <= 0 || output == nullptr) {
    return;
  }
  LogSpaceKernel<<<CUDA_BLOCKS(device_id, steps), CUDA_THREADS(device_id), 0, cuda_stream>>>(start, add, steps, base,
                                                                                             output);
}
// Explicit instantiations of CalLogSpace for the element types used by the kernel mod's
// dispatch table: half, float and double.
template CUDA_LIB_EXPORT void CalLogSpace<half>(const half start, const half add,
const int64_t steps, const size_t base,
half *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalLogSpace<float>(const float start, const float add,
const int64_t steps, const size_t base,
float *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalLogSpace<double>(const double start, const double add,
const int64_t steps, const size_t base,
double *output, const uint32_t &device_id, cudaStream_t cuda_stream);

View File

@ -0,0 +1,24 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LOGSPACE_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LOGSPACE_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
// Enqueues the LogSpace kernel on cuda_stream: output[i] = pow(base, start + add * i) for i in [0, steps).
//   start  - exponent of the first element
//   add    - exponent increment between consecutive elements (not the end value)
//   steps  - number of elements to write
//   base   - base of the power
//   output - device buffer with room for 'steps' elements
template <typename T>
CUDA_LIB_EXPORT void CalLogSpace(const T start, const T add, const int64_t steps, const size_t base, T *output,
                                 const uint32_t &device_id, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LOGSPACE_IMPL_CUH_

View File

@ -29,12 +29,14 @@ abstract::ShapePtr LogSpaceInferShape(const PrimitivePtr &primitive, const std::
auto start_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
auto start_shape = start_shape_map[kShape];
if (start_shape.size() != 0) {
MS_EXCEPTION(ValueError) << "For LogSpace, the dim of input[start] must be 0.";
MS_EXCEPTION(ValueError) << "For LogSpace, The dim of start must be 0, "
<< "but got " << start_shape.size();
}
auto end_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
auto end_shape = end_shape_map[kShape];
if (end_shape.size() != 0) {
MS_EXCEPTION(ValueError) << "For LogSpace, the dim of input[end] must be 0.";
MS_EXCEPTION(ValueError) << "For LogSpace, The dim of end must be 0, "
<< "but got " << end_shape.size();
}
int64_t shape_value = GetValue<int64_t>(primitive->GetAttr("steps"));
std::vector<int64_t> state_shape = {shape_value};
@ -43,7 +45,8 @@ abstract::ShapePtr LogSpaceInferShape(const PrimitivePtr &primitive, const std::
TypePtr LogSpaceInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
std::map<std::string, TypePtr> types;
(void)types.emplace("start", input_args[0]->BuildType());
(void)types.emplace("end", input_args[1]->BuildType());
@ -55,8 +58,18 @@ TypePtr LogSpaceInferType(const PrimitivePtr &prim, const std::vector<AbstractBa
return infer_type;
}
} // namespace
void LogSpace::Init(int64_t steps, int64_t base) {
set_steps(steps);
set_base(base);
}
// Stores 'steps' as the attribute named kSteps.
void LogSpace::set_steps(int64_t steps) {
  auto steps_value = api::MakeValue(steps);
  (void)this->AddAttr(kSteps, steps_value);
}
// Stores 'base' as the attribute named kBase.
void LogSpace::set_base(int64_t base) {
  auto base_value = api::MakeValue(base);
  (void)this->AddAttr(kBase, base_value);
}
// Reads the kSteps attribute as an int64_t.
int64_t LogSpace::get_steps() const {
  auto steps_attr = GetAttr(kSteps);
  return GetValue<int64_t>(steps_attr);
}
// Reads the kBase attribute as an int64_t.
int64_t LogSpace::get_base() const {
  auto base_attr = GetAttr(kBase);
  return GetValue<int64_t>(base_attr);
}
MIND_API_OPERATOR_IMPL(LogSpace, BaseOperator);
AbstractBasePtr LogSpaceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
@ -67,6 +80,7 @@ AbstractBasePtr LogSpaceInfer(const abstract::AnalysisEnginePtr &, const Primiti
auto infer_shape = LogSpaceInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
MIND_API_OPERATOR_IMPL(LogSpace, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(LogSpace, prim::kPrimLogSpace, LogSpaceInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -16,6 +16,7 @@
#ifndef MINDSPORE_CORE_OPS_LOG_SPACE_H_
#define MINDSPORE_CORE_OPS_LOG_SPACE_H_
#include <stdint.h>
#include <map>
#include <memory>
#include <set>
@ -37,6 +38,18 @@ class MIND_API LogSpace : public BaseOperator {
LogSpace() : BaseOperator(kNameLogSpace) { InitIOName({"start", "end"}, {"y"}); }
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.LogSpace for the inputs.
void Init() const {}
void Init(int64_t steps, int64_t base);
/// \brief Set steps.
///
/// \param[in] steps Value stored as the operator's `steps` attribute.
void set_steps(int64_t steps);
/// \brief Set base.
///
/// \param[in] base Value stored as the operator's `base` attribute.
void set_base(int64_t base);
/// \brief Get base.
///
/// \return The value of the `base` attribute.
int64_t get_base() const;
/// \brief Get steps.
///
/// \return The value of the `steps` attribute.
int64_t get_steps() const;
};
abstract::AbstractBasePtr LogSpaceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);

View File

@ -202,6 +202,7 @@ constexpr auto kSrcT = "src_t";
constexpr auto kStart = "start";
constexpr auto kStepH = "step_h";
constexpr auto kStepW = "step_w";
constexpr auto kSteps = "steps";
constexpr auto kStride = "stride";
constexpr auto kStrides = "strides";
constexpr auto kShapeType = "shape_type";

View File

@ -7432,25 +7432,27 @@ class RightShift(Primitive):
class LogSpace(Primitive):
r"""
Creates a one-dimensional tensor of size steps whose values are evenly spaced from base**start to base**end,
Returns a one-dimensional tensor of size steps whose values are evenly spaced from base**start to base**end,
inclusive, on a logarithmic scale with base.
.. math::s
.. math::
\begin{aligned}
&step = (end - start)/(steps - 1)\\
&output = [base**start,base**(start + (end-start)/(steps-1)),
base**(start + (steps-2)(end-start)/(steps-1)),
... , base**end]
&output = [base^{start}, base^{start + 1 * step}, ... , base^{start + (steps-2) * step}, base^{end}]
\end{aligned}
Args:
steps (int): The steps must be a non-negative integer.
base (int): The base must be a non-negative integer.
dtype (mindspore.dtype): The dtype of output, mindspore.float16 or mindspore.float32.
steps (int): The steps must be a non-negative integer. Default: 10.
base (int): The base must be a non-negative integer. Default: 10.
dtype (mindspore.dtype): The dtype of output, one of mindspore.float16, mindspore.float32 or
mindspore.float64 (GPU only). Default: mindspore.float32.
Inputs:
- **start** (Tensor) - Start value of interval, with shape of 0-D, dtype is float16 or float32.
- **end** (Tensor) - End value of interval, with shape of 0-D, dtype is float16 or float32.
- **start** (Tensor) - Start value of interval, with shape of 0-D,
dtype is float16, float32 or float64(for GPU).
- **end** (Tensor) - End value of interval, with shape of 0-D,
dtype is float16, float32 or float64(for GPU).
Outputs:
Tensor has the shape as (steps, ). Its datatype is set by the attr 'dtype'.
@ -7459,12 +7461,12 @@ class LogSpace(Primitive):
TypeError: If `input` is not a Tensor.
TypeError: If `steps` is not an int.
TypeError: If `base` is not an int.
TypeError: If `dtype` is not mindspore.float16 or mindspore.float32.
TypeError: If `dtype` is not mindspore.float16, mindspore.float32 or mindspore.float64(for GPU).
ValueError: If `steps` is not a non-negative integer.
ValueError: If `base` is not a non-negative integer.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> logspace = ops.LogSpace(steps = 10, base = 10, dtype=mindspore.float32)
@ -7476,14 +7478,14 @@ class LogSpace(Primitive):
"""
@prim_attr_register
def __init__(self, steps, base, dtype):
def __init__(self, steps=10, base=10, dtype=mstype.float32):
"""Initialize Logspace."""
validator.check_value_type("steps", steps, [int], self.name)
validator.check_value_type("base", base, [int], self.name)
validator.check_non_negative_int(steps, "steps", self.name)
validator.check_non_negative_int(base, "base", self.name)
validator.check_value_type("dtype", dtype, [mstype.Type], self.name)
valid_values = (mstype.float16, mstype.float32)
valid_values = (mstype.float16, mstype.float32, mstype.float64)
validator.check_type_name("dtype", dtype, valid_values, self.name)
self.init_prim_io_names(inputs=['start', 'end'], outputs=['y'])

View File

@ -0,0 +1,53 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.nn import Cell
import mindspore.ops.operations.array_ops as opa
class LogSpaceNet(Cell):
    """Cell that wraps a LogSpace primitive built with the given steps, base and dtype."""

    def __init__(self, steps=10, base=10, dtype=mstype.float32):
        super().__init__()
        self.ls_op = opa.LogSpace(steps, base, dtype)

    def construct(self, start, stop):
        # Forward both scalar tensors straight to the primitive.
        return self.ls_op(start, stop)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_log_space():
    """
    Feature: LogSpace operator on GPU.
    Description: Generate 20 float32 points from 2**-5 to 2**20 and compare with numpy.logspace.
    Expectation: The operator output matches the numpy result within tolerance.
    """
    first_exp, last_exp = -5, 20
    point_count = 20
    log_base = 2
    expected = np.logspace(first_exp, last_exp, num=point_count, base=log_base)

    start_tensor = Tensor(first_exp, dtype=mstype.float32)
    stop_tensor = Tensor(last_exp, dtype=mstype.float32)
    net = LogSpaceNet(point_count, log_base)
    actual = net(start_tensor, stop_tensor).asnumpy()

    assert np.allclose(actual, expected, rtol=1e-5, atol=1e-5)