forked from mindspore-Ecosystem/mindspore
!34164 [feat] [assistant] [ops] [I4ZZTZ] New GPU operator implementation, include LogSpace
Merge pull request !34164 from yvlee/logspace
This commit is contained in:
commit
b7d17ab460
|
@ -0,0 +1,110 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "plugin/device/gpu/kernel/arrays/logspace_gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/logspace_impl.cuh"
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
bool LogSpaceGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
auto kernel_ptr_ = std::dynamic_pointer_cast<ops::LogSpace>(base_operator);
|
||||
kernel_name_ = kernel_ptr_->name();
|
||||
// inputs and outputs should not be empty
|
||||
if (inputs.empty() || outputs.empty()) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid.";
|
||||
return false;
|
||||
}
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the kernel type should be in [ float16, float32, float64 ], "
|
||||
<< "but got: " << kernel_attr << ".";
|
||||
return false;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
|
||||
steps_ = kernel_ptr_->get_steps();
|
||||
base_ = kernel_ptr_->get_base();
|
||||
if (steps_ < 0) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the value of 'steps' should be larger than 0, "
|
||||
<< "but got " << steps_;
|
||||
return false;
|
||||
}
|
||||
{
|
||||
size_t input_size = 2 * unit_size_;
|
||||
input_size_list_.emplace_back(input_size);
|
||||
size_t output_size = steps_ * unit_size_;
|
||||
output_size_list_.emplace_back(output_size);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
int LogSpaceGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
for (const auto &input : inputs) {
|
||||
// If any input shape contains -1, means input shape is dynamic, so just return do nothing.
|
||||
auto input_shape = input->GetShapeVector();
|
||||
if (!IsValidShape(input_shape)) {
|
||||
return KRET_UNKNOWN_SHAPE;
|
||||
}
|
||||
}
|
||||
ResetResource();
|
||||
size_t input_size = 2 * unit_size_;
|
||||
input_size_list_.emplace_back(input_size);
|
||||
output_size_list_.emplace_back(steps_ * unit_size_);
|
||||
return KRET_OK;
|
||||
}
|
||||
void LogSpaceGpuKernelMod::ResetResource() noexcept {
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
}
|
||||
template <typename T>
|
||||
bool LogSpaceGpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (steps_ == 0) return true;
|
||||
auto start = GetDeviceAddress<T>(inputs, kIndex0);
|
||||
auto end = GetDeviceAddress<T>(inputs, kIndex1);
|
||||
T *output = GetDeviceAddress<T>(outputs, kIndex0);
|
||||
T host_start, host_end;
|
||||
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(&host_start, start, sizeof(T), cudaMemcpyDeviceToHost,
|
||||
reinterpret_cast<cudaStream_t>(cuda_stream_)),
|
||||
"For LogSpace, cudaMemcpy start failed");
|
||||
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
|
||||
cudaMemcpyAsync(&host_end, end, sizeof(T), cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(cuda_stream_)),
|
||||
"For LogSpace, cudaMemcpy end failed");
|
||||
T host_add = ((host_end - host_start) / (steps_ == 1 ? steps_ : steps_ - 1));
|
||||
CalLogSpace(host_start, host_add, steps_, base_, output, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream_));
|
||||
return true;
|
||||
}
|
||||
|
||||
// fp16, float, double
|
||||
std::vector<std::pair<KernelAttr, LogSpaceGpuKernelMod::LogSpaceFunc>> LogSpaceGpuKernelMod::func_list_ = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
&LogSpaceGpuKernelMod::LaunchKernel<half>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
&LogSpaceGpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
&LogSpaceGpuKernelMod::LaunchKernel<double>}};
|
||||
|
||||
std::vector<KernelAttr> LogSpaceGpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, LogSpaceFunc> &item) { return item.first; });
|
||||
return support_list;
|
||||
}
|
||||
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, LogSpace, LogSpaceGpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,74 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_LOGSPACE_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_LOGSPACE_GPU_KERNEL_H_
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include "mindspore/core/ops/log_space.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class LogSpaceGpuKernelMod : public NativeGpuKernelMod {
|
||||
public:
|
||||
LogSpaceGpuKernelMod() { ResetResource(); }
|
||||
~LogSpaceGpuKernelMod() override = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
|
||||
cuda_stream_ = cuda_stream;
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
};
|
||||
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
void ResetResource() noexcept;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
using LogSpaceFunc =
|
||||
std::function<bool(LogSpaceGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
|
||||
size_t unit_size_{1};
|
||||
int64_t steps_{10};
|
||||
size_t base_{10};
|
||||
std::optional<bool> is_input_dynamic_shape_{};
|
||||
void *cuda_stream_{nullptr};
|
||||
curandGenerator_t curand_generator_{nullptr};
|
||||
LogSpaceFunc kernel_func_{};
|
||||
static std::vector<std::pair<KernelAttr, LogSpaceFunc>> func_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_LOGSPACE_GPU_KERNEL_H_
|
|
@ -0,0 +1,60 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/logspace_impl.cuh"
|
||||
#include "include/cuda_runtime.h"
|
||||
#include "include/cuda_fp16.h"
|
||||
|
||||
|
||||
template <typename T>
|
||||
__global__ void LogSpaceKernel(const T start, const T add,
|
||||
const int64_t steps, const size_t base, T *output) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < steps; i += gridDim.x * blockDim.x) {
|
||||
output[i] = pow(base, start + (add * i));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <>
|
||||
__global__ void LogSpaceKernel(const half start, const half add,
|
||||
const int64_t steps, const size_t base, half *output) {
|
||||
float start_float = __half2float(start);
|
||||
float add_value = __half2float(add);
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < steps; i += gridDim.x * blockDim.x) {
|
||||
output[i] = __float2half(pow(base, start_float + (add_value * i)));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CalLogSpace(const T start, const T add, const int64_t steps, const size_t base, T *output,
|
||||
const uint32_t &device_id, cudaStream_t cuda_stream) {
|
||||
LogSpaceKernel<<<CUDA_BLOCKS(device_id, steps), CUDA_THREADS(device_id), 0,
|
||||
cuda_stream>>>(start, add, steps, base, output);
|
||||
return;
|
||||
}
|
||||
|
||||
template CUDA_LIB_EXPORT void CalLogSpace<half>(const half start, const half add,
|
||||
const int64_t steps, const size_t base,
|
||||
half *output, const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||
|
||||
template CUDA_LIB_EXPORT void CalLogSpace<float>(const float start, const float add,
|
||||
const int64_t steps, const size_t base,
|
||||
float *output, const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||
|
||||
template CUDA_LIB_EXPORT void CalLogSpace<double>(const double start, const double add,
|
||||
const int64_t steps, const size_t base,
|
||||
double *output, const uint32_t &device_id, cudaStream_t cuda_stream);
|
|
@ -0,0 +1,24 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LOGSPACE_IMPL_CUH_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LOGSPACE_IMPL_CUH_
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
|
||||
|
||||
template <typename T>
|
||||
CUDA_LIB_EXPORT void CalLogSpace(const T start, const T end, const int64_t steps, const size_t base, T *output,
|
||||
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LOGSPACE_IMPL_CUH_
|
|
@ -29,12 +29,14 @@ abstract::ShapePtr LogSpaceInferShape(const PrimitivePtr &primitive, const std::
|
|||
auto start_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||
auto start_shape = start_shape_map[kShape];
|
||||
if (start_shape.size() != 0) {
|
||||
MS_EXCEPTION(ValueError) << "For LogSpace, the dim of input[start] must be 0.";
|
||||
MS_EXCEPTION(ValueError) << "For LogSpace, The dim of start must be 0, "
|
||||
<< "but got " << start_shape.size();
|
||||
}
|
||||
auto end_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
|
||||
auto end_shape = end_shape_map[kShape];
|
||||
if (end_shape.size() != 0) {
|
||||
MS_EXCEPTION(ValueError) << "For LogSpace, the dim of input[end] must be 0.";
|
||||
MS_EXCEPTION(ValueError) << "For LogSpace, The dim of end must be 0, "
|
||||
<< "but got " << end_shape.size();
|
||||
}
|
||||
int64_t shape_value = GetValue<int64_t>(primitive->GetAttr("steps"));
|
||||
std::vector<int64_t> state_shape = {shape_value};
|
||||
|
@ -43,7 +45,8 @@ abstract::ShapePtr LogSpaceInferShape(const PrimitivePtr &primitive, const std::
|
|||
|
||||
TypePtr LogSpaceInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
|
||||
|
||||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("start", input_args[0]->BuildType());
|
||||
(void)types.emplace("end", input_args[1]->BuildType());
|
||||
|
@ -55,8 +58,18 @@ TypePtr LogSpaceInferType(const PrimitivePtr &prim, const std::vector<AbstractBa
|
|||
return infer_type;
|
||||
}
|
||||
} // namespace
|
||||
void LogSpace::Init(int64_t steps, int64_t base) {
|
||||
set_steps(steps);
|
||||
set_base(base);
|
||||
}
|
||||
|
||||
void LogSpace::set_steps(int64_t steps) { (void)this->AddAttr(kSteps, api::MakeValue(steps)); }
|
||||
void LogSpace::set_base(int64_t base) { (void)this->AddAttr(kBase, api::MakeValue(base)); }
|
||||
|
||||
int64_t LogSpace::get_steps() const { return GetValue<int64_t>(GetAttr(kSteps)); }
|
||||
|
||||
int64_t LogSpace::get_base() const { return GetValue<int64_t>(GetAttr(kBase)); }
|
||||
|
||||
MIND_API_OPERATOR_IMPL(LogSpace, BaseOperator);
|
||||
AbstractBasePtr LogSpaceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
@ -67,6 +80,7 @@ AbstractBasePtr LogSpaceInfer(const abstract::AnalysisEnginePtr &, const Primiti
|
|||
auto infer_shape = LogSpaceInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
MIND_API_OPERATOR_IMPL(LogSpace, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(LogSpace, prim::kPrimLogSpace, LogSpaceInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#ifndef MINDSPORE_CORE_OPS_LOG_SPACE_H_
|
||||
#define MINDSPORE_CORE_OPS_LOG_SPACE_H_
|
||||
#include <stdint.h>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
|
@ -37,6 +38,18 @@ class MIND_API LogSpace : public BaseOperator {
|
|||
LogSpace() : BaseOperator(kNameLogSpace) { InitIOName({"start", "end"}, {"y"}); }
|
||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.LogSpace for the inputs.
|
||||
void Init() const {}
|
||||
|
||||
void Init(int64_t steps, int64_t base);
|
||||
/// \brief Set steps.
|
||||
void set_steps(int64_t steps);
|
||||
/// \brief Set base.
|
||||
void set_base(int64_t base);
|
||||
|
||||
/// \return base.
|
||||
int64_t get_base() const;
|
||||
|
||||
/// \return steps.
|
||||
int64_t get_steps() const;
|
||||
};
|
||||
abstract::AbstractBasePtr LogSpaceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
|
|
|
@ -202,6 +202,7 @@ constexpr auto kSrcT = "src_t";
|
|||
constexpr auto kStart = "start";
|
||||
constexpr auto kStepH = "step_h";
|
||||
constexpr auto kStepW = "step_w";
|
||||
constexpr auto kSteps = "steps";
|
||||
constexpr auto kStride = "stride";
|
||||
constexpr auto kStrides = "strides";
|
||||
constexpr auto kShapeType = "shape_type";
|
||||
|
|
|
@ -7432,25 +7432,27 @@ class RightShift(Primitive):
|
|||
|
||||
class LogSpace(Primitive):
|
||||
r"""
|
||||
Creates a one-dimensional tensor of size steps whose values are evenly spaced from base**start to base**end,
|
||||
Returns a one-dimensional tensor of size steps whose values are evenly spaced from base**start to base**end,
|
||||
inclusive, on a logarithmic scale with base.
|
||||
|
||||
.. math::s
|
||||
.. math::
|
||||
\begin{aligned}
|
||||
&step = (end - start)/(steps - 1)\\
|
||||
&output = [base**start,base**(start + (end-start)/(steps-1)),
|
||||
base**(start + (steps-2)(end-start)/(steps-1)),
|
||||
... , base**end]
|
||||
&output = [base^{start}, base^{start + 1 * step}, ... , base^{start + (steps-2) * step}, base^{end}]
|
||||
\end{aligned}
|
||||
|
||||
Args:
|
||||
steps (int): The steps must be a non-negative integer.
|
||||
base (int): The base must be a non-negative integer.
|
||||
dtype (mindspore.dtype): The dtype of output, mindspore.float16 or mindspore.float32.
|
||||
steps (int): The steps must be a non-negative integer. default: 10
|
||||
base (int): The base must be a non-negative integer. default: 10
|
||||
dtype (mindspore.dtype): The dtype of output,
|
||||
include mindspore.float16, mindspore.float32 or mindspore.float64(for GPU).
|
||||
|
||||
|
||||
Inputs:
|
||||
- **start** (Tensor) - Start value of interval, with shape of 0-D, dtype is float16 or float32.
|
||||
- **end** (Tensor) - End value of interval, with shape of 0-D, dtype is float16 or float32.
|
||||
- **start** (Tensor) - Start value of interval, with shape of 0-D,
|
||||
dtype is float16, float32 or float64(for GPU).
|
||||
- **end** (Tensor) - End value of interval, with shape of 0-D,
|
||||
dtype is float16, float32 or float64(for GPU).
|
||||
|
||||
Outputs:
|
||||
Tensor has the shape as (step, ). Its datatype is set by the attr 'dtype'.
|
||||
|
@ -7459,12 +7461,12 @@ class LogSpace(Primitive):
|
|||
TypeError: If `input` is not a Tensor.
|
||||
TypeError: If `steps` is not an int.
|
||||
TypeError: If `base` is not an int.
|
||||
TypeError: If `dtype` is not mindspore.float16 or mindspore.float32.
|
||||
TypeError: If `dtype` is not mindspore.float16, mindspore.float32 or mindspore.float64(for GPU).
|
||||
ValueError: If `steps` is not a non-negative integer.
|
||||
ValueError: If `base` is not a non-negative integer.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> logspace = ops.LogSpace(steps = 10, base = 10, dtype=mindspore.float32)
|
||||
|
@ -7476,14 +7478,14 @@ class LogSpace(Primitive):
|
|||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, steps, base, dtype):
|
||||
def __init__(self, steps=10, base=10, dtype=mstype.float32):
|
||||
"""Initialize Logspace."""
|
||||
validator.check_value_type("steps", steps, [int], self.name)
|
||||
validator.check_value_type("base", base, [int], self.name)
|
||||
validator.check_non_negative_int(steps, "steps", self.name)
|
||||
validator.check_non_negative_int(base, "base", self.name)
|
||||
validator.check_value_type("dtype", dtype, [mstype.Type], self.name)
|
||||
valid_values = (mstype.float16, mstype.float32)
|
||||
valid_values = (mstype.float16, mstype.float32, mstype.float64)
|
||||
validator.check_type_name("dtype", dtype, valid_values, self.name)
|
||||
self.init_prim_io_names(inputs=['start', 'end'], outputs=['y'])
|
||||
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.nn import Cell
|
||||
import mindspore.ops.operations.array_ops as opa
|
||||
|
||||
|
||||
class LogSpaceNet(Cell):
|
||||
def __init__(self, steps=10, base=10, dtype=mstype.float32):
|
||||
super(LogSpaceNet, self).__init__()
|
||||
self.ls_op = opa.LogSpace(steps, base, dtype)
|
||||
|
||||
def construct(self, start, stop):
|
||||
output = self.ls_op(start, stop)
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_log_space():
|
||||
"""
|
||||
Feature: Create a one-dimensional tensor of size steps whose values are spaced from base**start to base**end,
|
||||
Description: test cases for logspace
|
||||
Expectation: the result match to numpy
|
||||
"""
|
||||
start_np = -5
|
||||
stop_np = 20
|
||||
num_np = 20
|
||||
base_np = 2
|
||||
result_np = np.logspace(start_np, stop_np, num_np, base=base_np)
|
||||
start = Tensor(start_np, dtype=mstype.float32)
|
||||
stop = Tensor(stop_np, dtype=mstype.float32)
|
||||
net_g = LogSpaceNet(num_np, base_np)
|
||||
result_g = net_g(start, stop).asnumpy()
|
||||
assert np.allclose(result_g, result_np, 1e-5, 1e-5)
|
Loading…
Reference in New Issue