forked from mindspore-Ecosystem/mindspore
parent
8ce39575c7
commit
f6f99d6bd5
|
@ -15,11 +15,11 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "plugin/device/cpu/kernel/fill_v2_cpu_kernel.h"
|
#include "plugin/device/cpu/kernel/fill_v2_cpu_kernel.h"
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <complex>
|
||||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -31,159 +31,76 @@ constexpr size_t kFillV2OutputsNum = 1;
|
||||||
|
|
||||||
bool FillV2CpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
bool FillV2CpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
const std::vector<KernelTensorPtr> &outputs) {
|
const std::vector<KernelTensorPtr> &outputs) {
|
||||||
|
MS_EXCEPTION_IF_NULL(base_operator);
|
||||||
kernel_name_ = base_operator->GetPrim()->name();
|
kernel_name_ = base_operator->GetPrim()->name();
|
||||||
input1_dtype_ = inputs[0]->GetDtype();
|
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||||
output_dtype_ = outputs[0]->GetDtype();
|
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||||
|
if (!is_match) {
|
||||||
|
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
kernel_func_ = func_list_[index].second;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
int FillV2CpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
template <typename T>
|
||||||
const std::vector<KernelTensorPtr> &outputs,
|
bool FillV2CpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
const std::vector<kernel::AddressPtr> &workspace,
|
||||||
auto ret = KernelMod::Resize(base_operator, inputs, outputs);
|
|
||||||
if (ret != KRET_OK) {
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
output_shape_ = outputs[0]->GetDeviceShapeAdaptively();
|
|
||||||
return KRET_OK;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool FillV2CpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
|
|
||||||
const std::vector<kernel::AddressPtr> &outputs) {
|
const std::vector<kernel::AddressPtr> &outputs) {
|
||||||
// Check the number of input and output
|
const auto output = outputs[kIndex0];
|
||||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kFillV2InputsNum, kernel_name_);
|
auto *output_data = reinterpret_cast<T *>(output->addr);
|
||||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kFillV2OutputsNum, kernel_name_);
|
auto *value_data = reinterpret_cast<T *>(inputs[kIndex1]->addr);
|
||||||
|
size_t lens = static_cast<size_t>(output->size / sizeof(T));
|
||||||
// Get the shape of the output based on the first input
|
|
||||||
std::vector<int64_t> dims;
|
|
||||||
switch (input1_dtype_) {
|
|
||||||
case (kNumberTypeInt32):
|
|
||||||
CalculateDims<int32_t>(inputs[0], &dims);
|
|
||||||
break;
|
|
||||||
case (kNumberTypeInt64):
|
|
||||||
CalculateDims<int64_t>(inputs[0], &dims);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
MS_LOG(EXCEPTION) << "the datatype of the input1 not support, support datatype: int32, int64.";
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check output shape
|
|
||||||
auto output = outputs[0];
|
|
||||||
std::vector<int64_t> output_new_shape_;
|
|
||||||
auto num = output_shape_.size();
|
|
||||||
for (size_t i = 0; i < num; i++) {
|
|
||||||
auto element = output_shape_[i];
|
|
||||||
output_new_shape_.emplace_back(element);
|
|
||||||
}
|
|
||||||
if (output_new_shape_ != dims) {
|
|
||||||
MS_LOG(EXCEPTION) << "the shape of output is error, the data of the input1 not match the shape of the output.";
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fill according to the different data types of the output
|
|
||||||
auto value = inputs[1];
|
|
||||||
switch (output_dtype_) {
|
|
||||||
case (kNumberTypeBool):
|
|
||||||
LaunchKernel<bool>(&output, value);
|
|
||||||
break;
|
|
||||||
case (kNumberTypeInt8):
|
|
||||||
LaunchKernel<int8_t>(&output, value);
|
|
||||||
break;
|
|
||||||
case (kNumberTypeInt16):
|
|
||||||
LaunchKernel<int16_t>(&output, value);
|
|
||||||
break;
|
|
||||||
case (kNumberTypeInt32):
|
|
||||||
LaunchKernel<int32_t>(&output, value);
|
|
||||||
break;
|
|
||||||
case (kNumberTypeInt64):
|
|
||||||
LaunchKernel<int64_t>(&output, value);
|
|
||||||
break;
|
|
||||||
case (kNumberTypeUInt8):
|
|
||||||
LaunchKernel<uint8_t>(&output, value);
|
|
||||||
break;
|
|
||||||
case (kNumberTypeUInt16):
|
|
||||||
LaunchKernel<uint16_t>(&output, value);
|
|
||||||
break;
|
|
||||||
case (kNumberTypeUInt32):
|
|
||||||
LaunchKernel<uint32_t>(&output, value);
|
|
||||||
break;
|
|
||||||
case (kNumberTypeUInt64):
|
|
||||||
LaunchKernel<uint64_t>(&output, value);
|
|
||||||
break;
|
|
||||||
case (kNumberTypeFloat16):
|
|
||||||
LaunchKernel<float16>(&output, value);
|
|
||||||
break;
|
|
||||||
case (kNumberTypeFloat32):
|
|
||||||
LaunchKernel<float>(&output, value);
|
|
||||||
break;
|
|
||||||
case (kNumberTypeFloat64):
|
|
||||||
LaunchKernel<double>(&output, value);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
MS_LOG(EXCEPTION) << "the datatype of the input2 not support, support datatype: "
|
|
||||||
"bool, int8, int16, int32, int64, uint8, uint16, uint32, "
|
|
||||||
"uint64, float16, float32, float64.";
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void FillV2CpuKernelMod::CalculateDims(const AddressPtr &input, std::vector<int64_t> *dims) const {
|
|
||||||
MS_EXCEPTION_IF_NULL(input);
|
|
||||||
auto *input_data = reinterpret_cast<T *>(input->addr);
|
|
||||||
size_t data_num = input->size / sizeof(T);
|
|
||||||
for (size_t i = 0; i < data_num; i++) {
|
|
||||||
auto dim = static_cast<int64_t>(input_data[i]);
|
|
||||||
if (dim < 0) {
|
|
||||||
MS_LOG(EXCEPTION) << "the data of the input1 must all be greater than 0, there is a negative value in input1.";
|
|
||||||
}
|
|
||||||
if (dim == 0) {
|
|
||||||
MS_LOG(EXCEPTION) << "the data of the input1 must all be greater than 0, there is a zero value in input1.";
|
|
||||||
}
|
|
||||||
(*dims).emplace_back(dim);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void FillV2CpuKernelMod::LaunchKernel(AddressPtr *output, const AddressPtr &value) {
|
|
||||||
auto *output_data = reinterpret_cast<T *>((*output)->addr);
|
|
||||||
auto *value_data = reinterpret_cast<T *>((value->addr));
|
|
||||||
size_t lens = static_cast<size_t>((*output)->size / sizeof(T));
|
|
||||||
auto task = [output_data, value_data](const size_t start, const size_t end) {
|
auto task = [output_data, value_data](const size_t start, const size_t end) {
|
||||||
for (size_t i = start; i < end; i++) {
|
for (size_t i = start; i < end; i++) {
|
||||||
output_data[i] = *value_data;
|
output_data[i] = *value_data;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
ParallelLaunchAutoSearch(task, lens, this, ¶llel_search_info_);
|
ParallelLaunchAutoSearch(task, lens, this, ¶llel_search_info_);
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define FILL_V2_CPU_REG(MS_T, MS_S, T) \
|
||||||
|
KernelAttr().AddInputAttr(MS_T).AddInputAttr(MS_S).AddOutputAttr(MS_S), &FillV2CpuKernelMod::LaunchKernel<T>
|
||||||
|
|
||||||
|
std::vector<std::pair<KernelAttr, FillV2CpuKernelMod::FillV2LaunchFunc>> FillV2CpuKernelMod::func_list_ = {
|
||||||
|
{FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeBool, bool)},
|
||||||
|
{FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeInt8, int8_t)},
|
||||||
|
{FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeInt16, int16_t)},
|
||||||
|
{FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeInt32, int32_t)},
|
||||||
|
{FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeInt64, int64_t)},
|
||||||
|
{FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeUInt8, uint8_t)},
|
||||||
|
{FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeUInt16, uint16_t)},
|
||||||
|
{FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeUInt32, uint32_t)},
|
||||||
|
{FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeUInt64, uint64_t)},
|
||||||
|
{FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeFloat16, float16)},
|
||||||
|
{FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeFloat32, float)},
|
||||||
|
{FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeFloat64, double)},
|
||||||
|
{FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeComplex64, std::complex<float>)},
|
||||||
|
{FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeComplex128, std::complex<double>)},
|
||||||
|
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeBool, bool)},
|
||||||
|
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeInt8, int8_t)},
|
||||||
|
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeInt16, int16_t)},
|
||||||
|
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeInt64, int32_t)},
|
||||||
|
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeInt64, int64_t)},
|
||||||
|
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeUInt8, uint8_t)},
|
||||||
|
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeUInt16, uint16_t)},
|
||||||
|
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeUInt32, uint32_t)},
|
||||||
|
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeUInt64, uint64_t)},
|
||||||
|
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeFloat16, float16)},
|
||||||
|
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeFloat32, float)},
|
||||||
|
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeFloat64, double)},
|
||||||
|
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeComplex64, std::complex<float>)},
|
||||||
|
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeComplex128, std::complex<double>)}};
|
||||||
|
|
||||||
std::vector<KernelAttr> FillV2CpuKernelMod::GetOpSupport() {
|
std::vector<KernelAttr> FillV2CpuKernelMod::GetOpSupport() {
|
||||||
static std::vector<KernelAttr> kernel_attr_list = {
|
std::vector<KernelAttr> support_list;
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
(void)std::transform(
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
[](const std::pair<KernelAttr, FillV2CpuKernelMod::FillV2LaunchFunc> &pair) { return pair.first; });
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
return support_list;
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)};
|
|
||||||
return kernel_attr_list;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, FillV2, FillV2CpuKernelMod);
|
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, FillV2, FillV2CpuKernelMod);
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <utility>
|
||||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||||
#include "plugin/factory/ms_factory.h"
|
#include "plugin/factory/ms_factory.h"
|
||||||
|
|
||||||
|
@ -32,24 +33,23 @@ class FillV2CpuKernelMod : public NativeCpuKernelMod {
|
||||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
const std::vector<KernelTensorPtr> &outputs) override;
|
const std::vector<KernelTensorPtr> &outputs) override;
|
||||||
|
|
||||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
|
||||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
|
||||||
|
|
||||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
const std::vector<AddressPtr> &outputs) override;
|
const std::vector<AddressPtr> &outputs) override {
|
||||||
|
return kernel_func_(this, inputs, workspace, outputs);
|
||||||
template <typename T>
|
}
|
||||||
void LaunchKernel(AddressPtr *output, const AddressPtr &value);
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void CalculateDims(const AddressPtr &input, std::vector<int64_t> *dims) const;
|
|
||||||
|
|
||||||
|
protected:
|
||||||
std::vector<KernelAttr> GetOpSupport() override;
|
std::vector<KernelAttr> GetOpSupport() override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
TypeId output_dtype_{kTypeUnknown};
|
template <typename T>
|
||||||
TypeId input1_dtype_{kTypeUnknown};
|
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
ShapeVector output_shape_;
|
const std::vector<AddressPtr> &outputs);
|
||||||
|
|
||||||
|
using FillV2LaunchFunc = std::function<bool(FillV2CpuKernelMod *, const std::vector<AddressPtr> &,
|
||||||
|
const std::vector<AddressPtr> &, const std::vector<AddressPtr> &)>;
|
||||||
|
static std::vector<std::pair<KernelAttr, FillV2LaunchFunc>> func_list_;
|
||||||
|
FillV2LaunchFunc kernel_func_;
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -0,0 +1,129 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "plugin/device/gpu/kernel/arrays/fill_v2_gpu_kernel.h"
|
||||||
|
#include <functional>
|
||||||
|
#include <utility>
|
||||||
|
#include <string>
|
||||||
|
#include <algorithm>
|
||||||
|
#include "mindspore/core/abstract/utils.h"
|
||||||
|
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fill_v2_impl.cuh"
|
||||||
|
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
|
||||||
|
#include "kernel/common_utils.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
namespace {
// Expected number of kernel inputs (shape tensor, fill value) and outputs. These are compared against
// std::vector::size() results in LaunchKernel, so they are declared as size_t (matching the CPU kernel)
// rather than int to avoid a signed/unsigned comparison.
constexpr size_t kFillV2InputsNum = 2;
constexpr size_t kFillV2OutputsNum = 1;
}  // namespace
|
||||||
|
|
||||||
|
bool FillV2GpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
const std::vector<KernelTensorPtr> &outputs) {
|
||||||
|
MS_EXCEPTION_IF_NULL(base_operator);
|
||||||
|
kernel_name_ = base_operator->name();
|
||||||
|
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||||
|
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||||
|
if (!is_match) {
|
||||||
|
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
kernel_func_ = func_list_[index].second;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
int FillV2GpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
const std::vector<KernelTensorPtr> &outputs,
|
||||||
|
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||||
|
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
output_shape_ = outputs.at(kIndex0)->GetShapeVector();
|
||||||
|
output_size_ = SizeToLong(SizeOf(output_shape_));
|
||||||
|
return KRET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename DataType>
|
||||||
|
bool FillV2GpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
|
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||||
|
if (output_size_ == 0) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kFillV2InputsNum, kernel_name_);
|
||||||
|
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kFillV2OutputsNum, kernel_name_);
|
||||||
|
|
||||||
|
cuda_stream_ = reinterpret_cast<cudaStream_t>(stream_ptr);
|
||||||
|
|
||||||
|
DataType *input_ptr = GetDeviceAddress<DataType>(inputs, kIndex1);
|
||||||
|
MS_ERROR_IF_NULL_W_RET_VAL(input_ptr, false);
|
||||||
|
|
||||||
|
DataType *output_ptr = GetDeviceAddress<DataType>(outputs, kIndex0);
|
||||||
|
MS_ERROR_IF_NULL_W_RET_VAL(output_ptr, false);
|
||||||
|
|
||||||
|
FillV2(output_size_, input_ptr, output_ptr, device_id_, cuda_stream_);
|
||||||
|
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaGetLastError(), "FillV2 kernel failed.");
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
#define FILL_V2_GPU_REG(MS_T, MS_S, T) \
|
||||||
|
KernelAttr().AddInputAttr(MS_T).AddInputAttr(MS_S).AddOutputAttr(MS_S), &FillV2GpuKernelMod::LaunchKernel<T>
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
using Complex = mindspore::utils::Complex<T>;
|
||||||
|
|
||||||
|
std::vector<std::pair<KernelAttr, FillV2GpuKernelMod::FillV2LaunchFunc>> FillV2GpuKernelMod::func_list_ = {
|
||||||
|
{FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeBool, bool)},
|
||||||
|
{FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeInt8, int8_t)},
|
||||||
|
{FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeInt16, int16_t)},
|
||||||
|
{FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeInt32, int32_t)},
|
||||||
|
{FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeInt64, int64_t)},
|
||||||
|
{FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeUInt8, uint8_t)},
|
||||||
|
{FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeUInt16, uint16_t)},
|
||||||
|
{FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeUInt32, uint32_t)},
|
||||||
|
{FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeUInt64, uint64_t)},
|
||||||
|
{FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeFloat16, half)},
|
||||||
|
{FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeFloat32, float)},
|
||||||
|
{FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeFloat64, double)},
|
||||||
|
{FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeComplex64, Complex<float>)},
|
||||||
|
{FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeComplex128, Complex<double>)},
|
||||||
|
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeBool, bool)},
|
||||||
|
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeInt8, int8_t)},
|
||||||
|
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeInt16, int16_t)},
|
||||||
|
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeInt64, int32_t)},
|
||||||
|
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeInt64, int64_t)},
|
||||||
|
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeUInt8, uint8_t)},
|
||||||
|
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeUInt16, uint16_t)},
|
||||||
|
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeUInt32, uint32_t)},
|
||||||
|
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeUInt64, uint64_t)},
|
||||||
|
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeFloat16, half)},
|
||||||
|
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeFloat32, float)},
|
||||||
|
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeFloat64, double)},
|
||||||
|
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeComplex64, Complex<float>)},
|
||||||
|
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeComplex128, Complex<double>)}};
|
||||||
|
|
||||||
|
std::vector<KernelAttr> FillV2GpuKernelMod::GetOpSupport() {
|
||||||
|
std::vector<KernelAttr> support_list;
|
||||||
|
(void)std::transform(
|
||||||
|
func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||||
|
[](const std::pair<KernelAttr, FillV2GpuKernelMod::FillV2LaunchFunc> &pair) { return pair.first; });
|
||||||
|
return support_list;
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, FillV2, FillV2GpuKernelMod);
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,65 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_FILL_V2_GPU_KERNEL_H_
|
||||||
|
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_FILL_V2_GPU_KERNEL_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#include <map>
|
||||||
|
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||||
|
#include "plugin/factory/ms_factory.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
class FillV2GpuKernelMod : public NativeGpuKernelMod {
|
||||||
|
public:
|
||||||
|
FillV2GpuKernelMod() = default;
|
||||||
|
~FillV2GpuKernelMod() override = default;
|
||||||
|
|
||||||
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
|
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||||
|
return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
const std::vector<KernelTensorPtr> &outputs) override;
|
||||||
|
|
||||||
|
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::vector<KernelAttr> GetOpSupport() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
template <typename DataType>
|
||||||
|
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
|
const std::vector<AddressPtr> &outputs, void *stream_ptr);
|
||||||
|
|
||||||
|
using FillV2LaunchFunc =
|
||||||
|
std::function<bool(FillV2GpuKernelMod *, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
|
||||||
|
const std::vector<AddressPtr> &, void *)>;
|
||||||
|
static std::vector<std::pair<KernelAttr, FillV2LaunchFunc>> func_list_;
|
||||||
|
FillV2LaunchFunc kernel_func_;
|
||||||
|
cudaStream_t cuda_stream_;
|
||||||
|
std::vector<int64_t> output_shape_{};
|
||||||
|
int64_t output_size_{0};
|
||||||
|
};
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_FILL_V2_GPU_KERNEL_H_
|
|
@ -0,0 +1,64 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fill_v2_impl.cuh"
|
||||||
|
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
|
||||||
|
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
using Complex = mindspore::utils::Complex<T>;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void FillV2Kernel(const int64_t output_size, const T *input, T *output) {
|
||||||
|
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_size; pos += blockDim.x * gridDim.x) {
|
||||||
|
output[pos] = input[0];
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void FillV2(const int64_t output_size, const T *input, T *output, const uint32_t device_id, cudaStream_t stream) {
|
||||||
|
FillV2Kernel<<<CUDA_BLOCKS(device_id, output_size), CUDA_THREADS(device_id), 0, stream>>>(output_size, input, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Explicit instantiations of FillV2 for every element type registered by the GPU kernel module. The template
// argument is spelled out on each line so the instantiated type is visible without reading the parameter list.
template CUDA_LIB_EXPORT void FillV2<bool>(const int64_t output_size, const bool *input, bool *output,
                                           const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2<int8_t>(const int64_t output_size, const int8_t *input, int8_t *output,
                                             const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2<int16_t>(const int64_t output_size, const int16_t *input, int16_t *output,
                                              const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2<int32_t>(const int64_t output_size, const int32_t *input, int32_t *output,
                                              const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2<int64_t>(const int64_t output_size, const int64_t *input, int64_t *output,
                                              const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2<uint8_t>(const int64_t output_size, const uint8_t *input, uint8_t *output,
                                              const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2<uint16_t>(const int64_t output_size, const uint16_t *input, uint16_t *output,
                                               const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2<uint32_t>(const int64_t output_size, const uint32_t *input, uint32_t *output,
                                               const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2<uint64_t>(const int64_t output_size, const uint64_t *input, uint64_t *output,
                                               const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2<half>(const int64_t output_size, const half *input, half *output,
                                           const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2<float>(const int64_t output_size, const float *input, float *output,
                                            const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2<double>(const int64_t output_size, const double *input, double *output,
                                             const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2<Complex<float>>(const int64_t output_size, const Complex<float> *input,
                                                     Complex<float> *output, const uint32_t device_id,
                                                     cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2<Complex<double>>(const int64_t output_size, const Complex<double> *input,
                                                      Complex<double> *output, const uint32_t device_id,
                                                      cudaStream_t stream);
|
|
@ -0,0 +1,25 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FILL_V2_IMPL_CUH_
|
||||||
|
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FILL_V2_IMPL_CUH_
|
||||||
|
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
|
||||||
|
|
||||||
|
// Launches the FillV2 CUDA kernel on `stream`: every one of the `output_size` elements of `output` is set to
// `input[0]`. `device_id` is used to pick the launch configuration. Defined and explicitly instantiated in the
// matching .cu file.
template <typename T>
|
||||||
|
CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const T *input, T *output, const uint32_t device_id,
|
||||||
|
cudaStream_t stream);
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FILL_V2_IMPL_CUH_
|
|
@ -29,143 +29,73 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
MIND_API_OPERATOR_IMPL(FillV2, BaseOperator);
|
namespace {
|
||||||
abstract::ShapePtr FillV2InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
abstract::ShapePtr FillV2InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
if (!input_args[0]->isa<abstract::AbstractTensor>()) {
|
|
||||||
MS_EXCEPTION(TypeError) << "For '" << primitive->name() << "', input[0] must be tensor.";
|
|
||||||
}
|
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
const uint32_t kInputDims = 1;
|
auto prim_name = primitive->name();
|
||||||
|
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||||
|
auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||||
|
|
||||||
auto max_length_ptr = primitive->GetAttr("max_length");
|
auto max_length_ptr = primitive->GetAttr("max_length");
|
||||||
MS_EXCEPTION_IF_NULL(max_length_ptr);
|
MS_EXCEPTION_IF_NULL(max_length_ptr);
|
||||||
int64_t max_length = GetValue<int64_t>(max_length_ptr);
|
const int64_t max_length = GetValue<int64_t>(max_length_ptr);
|
||||||
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
const int64_t kDimOne = 1;
|
||||||
if (input1_shape.size() != 1) {
|
const int64_t kDimZero = 0;
|
||||||
MS_EXCEPTION(ValueError) << "For '" << primitive->name()
|
|
||||||
<< "', the shape size of 'input1' must be 1, but got: " << input1_shape.size() << ".";
|
CheckAndConvertUtils::CheckInteger("rank of shape", SizeToLong(input1_shape.size()), kEqual, kDimOne, prim_name);
|
||||||
}
|
if (!IsDynamic(input2_shape)) {
|
||||||
auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
CheckAndConvertUtils::CheckInteger("rank of value", SizeToLong(input2_shape.size()), kEqual, kDimZero, prim_name);
|
||||||
if (input2_shape.size() != 0) {
|
|
||||||
MS_EXCEPTION(ValueError) << "For '" << primitive->name()
|
|
||||||
<< "', the shape size of 'input2' must be 0, but got: " << input2_shape.size() << ".";
|
|
||||||
}
|
|
||||||
auto input_shape = input_args[0]->cast<abstract::AbstractTensorPtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(input_shape);
|
|
||||||
auto input_shape_value_ptr = input_shape->BuildValue();
|
|
||||||
MS_EXCEPTION_IF_NULL(input_shape_value_ptr);
|
|
||||||
auto input_shape_tensor = input_shape_value_ptr->cast<tensor::TensorPtr>();
|
|
||||||
auto input_type = input_args[0]->BuildType();
|
|
||||||
MS_EXCEPTION_IF_NULL(input_type);
|
|
||||||
auto input_type_id = input_type->cast<TensorTypePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(input_type_id);
|
|
||||||
auto input_type_element = input_type_id->element();
|
|
||||||
MS_EXCEPTION_IF_NULL(input_type_element);
|
|
||||||
auto shape_ptr = std::make_shared<abstract::Shape>(
|
|
||||||
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]);
|
|
||||||
auto shape_v = shape_ptr->shape();
|
|
||||||
if (shape_v.size() != kInputDims) {
|
|
||||||
MS_EXCEPTION(ValueError) << "For '" << primitive->name()
|
|
||||||
<< "', input must be a 1-D tensor, but got a: " << shape_v.size() << "-D tensor.";
|
|
||||||
}
|
|
||||||
if (!input_args[0]->BuildValue()->isa<AnyValue>() && !input_args[0]->BuildValue()->isa<None>()) {
|
|
||||||
std::vector<int64_t> out_shape;
|
|
||||||
int64_t shape_m = 1;
|
|
||||||
if (input_type_element->type_id() == kNumberTypeInt32) {
|
|
||||||
auto input_shape_ptr = reinterpret_cast<int32_t *>(input_shape_tensor->data_c());
|
|
||||||
for (auto i = 0; i < shape_v[0]; ++i) {
|
|
||||||
if (input_shape_ptr[i] > 0) {
|
|
||||||
out_shape.push_back(input_shape_ptr[i]);
|
|
||||||
shape_m *= input_shape_ptr[i];
|
|
||||||
} else {
|
|
||||||
MS_EXCEPTION(ValueError) << "For '" << primitive->name()
|
|
||||||
<< "', each dimension of input shape must be greater than 0, but got input shape "
|
|
||||||
<< i << ": " << input_shape_ptr[i] << ".";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if (input_type_element->type_id() == kNumberTypeInt64) {
|
|
||||||
auto input_shape_ptr = reinterpret_cast<int64_t *>(input_shape_tensor->data_c());
|
|
||||||
for (auto i = 0; i < shape_v[0]; ++i) {
|
|
||||||
if (input_shape_ptr[i] > 0) {
|
|
||||||
out_shape.push_back(input_shape_ptr[i]);
|
|
||||||
shape_m *= static_cast<int64_t>(input_shape_ptr[i]);
|
|
||||||
} else {
|
|
||||||
MS_EXCEPTION(ValueError) << "For '" << primitive->name()
|
|
||||||
<< "', each dimension of input shape must be greater than 0, but got input shape "
|
|
||||||
<< i << ": " << input_shape_ptr[i] << ".";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
MS_EXCEPTION(TypeError) << "For '" << primitive->name()
|
|
||||||
<< "', the dtype of input1 must be in [int32, int64], but got: "
|
|
||||||
<< input_type_element->type_id() << ".";
|
|
||||||
}
|
|
||||||
if (shape_m > max_length) {
|
|
||||||
MS_EXCEPTION(ValueError)
|
|
||||||
<< "For '" << primitive->name()
|
|
||||||
<< "', the number of elements of output must be less than 'max_length', but got number of elements: " << shape_m
|
|
||||||
<< ", 'max_length': " << max_length << ".";
|
|
||||||
}
|
|
||||||
return std::make_shared<abstract::Shape>(out_shape);
|
|
||||||
} else {
|
|
||||||
std::vector<int64_t> output_shape;
|
|
||||||
if (shape_v[0] > 0) {
|
|
||||||
for (int i = 0; i < shape_v[0]; i++) {
|
|
||||||
output_shape.push_back(abstract::Shape::kShapeDimAny);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (uint16_t i = 0; i < shape_v.size(); i++) {
|
|
||||||
output_shape.push_back(abstract::Shape::kShapeDimAny);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (input_args[kInputIndex0]->isa<abstract::AbstractTensor>() &&
|
||||||
|
input_args[kInputIndex0]->BuildValue()->isa<tensor::Tensor>()) {
|
||||||
|
auto value_ptr = input_args[kInputIndex0]->BuildValue();
|
||||||
|
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||||
|
auto output_shape = CheckAndConvertUtils::CheckTensorIntValue("shape", value_ptr, prim_name);
|
||||||
|
for (size_t i = 0; i < output_shape.size(); ++i) {
|
||||||
|
CheckAndConvertUtils::CheckInteger("the " + std::to_string(i) + "th dimension of input shape", output_shape[i],
|
||||||
|
kGreaterThan, kDimZero, prim_name);
|
||||||
}
|
}
|
||||||
|
CheckAndConvertUtils::CheckInteger("the number of elements of output", SizeToLong(SizeOf(output_shape)), kLessEqual,
|
||||||
|
max_length, prim_name);
|
||||||
return std::make_shared<abstract::Shape>(output_shape);
|
return std::make_shared<abstract::Shape>(output_shape);
|
||||||
|
} else {
|
||||||
|
return std::make_shared<abstract::Shape>(std::vector<int64_t>{-2});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TypePtr FillV2InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
TypePtr FillV2InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
auto prim_name = primitive->name();
|
auto prim_name = primitive->name();
|
||||||
|
auto input1_type = input_args[kInputIndex0]->BuildType();
|
||||||
|
auto input2_type = input_args[kInputIndex1]->BuildType();
|
||||||
|
|
||||||
// Check the data type of the first input
|
// Check the data type of the first input
|
||||||
auto input1 = input_args[kInputIndex0];
|
|
||||||
auto input1_type = input1->BuildType();
|
|
||||||
MS_EXCEPTION_IF_NULL(input1);
|
|
||||||
if (input1->isa<abstract::AbstractTensor>()) {
|
|
||||||
const std::set<TypePtr> input1_valid_types = {kInt32, kInt64};
|
const std::set<TypePtr> input1_valid_types = {kInt32, kInt64};
|
||||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("input1 datatype", input1_type, input1_valid_types, prim_name);
|
(void)CheckAndConvertUtils::CheckTensorTypeValid("input1 datatype", input1_type, input1_valid_types, prim_name);
|
||||||
} else {
|
|
||||||
MS_EXCEPTION(TypeError) << "For '" << primitive->name()
|
|
||||||
<< "', the dtype of input1 must be in [int32, int64], but got: " << input1_type->ToString()
|
|
||||||
<< ".";
|
|
||||||
}
|
|
||||||
// Check the data type of the second input and infer the data type of the output from the second input
|
// Check the data type of the second input and infer the data type of the output from the second input
|
||||||
auto input2 = input_args[kInputIndex1];
|
(void)CheckAndConvertUtils::CheckTensorTypeValid("output datatype", input2_type,
|
||||||
auto input2_type = input2->BuildType();
|
common_valid_types_with_complex_and_bool, prim_name);
|
||||||
MS_EXCEPTION_IF_NULL(input2);
|
|
||||||
if (input2->isa<abstract::AbstractTensor>()) {
|
|
||||||
auto output_valid_types = common_valid_types;
|
|
||||||
(void)output_valid_types.insert(kBool);
|
|
||||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("output datatype", input2_type, output_valid_types, prim_name);
|
|
||||||
} else {
|
|
||||||
MS_EXCEPTION(TypeError)
|
|
||||||
<< "For '" << prim_name
|
|
||||||
<< "', the dtype of input2 must be in [bool, int8, int16, int32, int64, uint8, uint16, uint32, "
|
|
||||||
"uint64, float16, float32, float64], but got: "
|
|
||||||
<< input2_type->ToString() << ".";
|
|
||||||
}
|
|
||||||
auto input2_tensor_type = (input2_type->cast<TensorTypePtr>())->element();
|
|
||||||
|
|
||||||
return input2_tensor_type;
|
return input2_type;
|
||||||
}
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
AbstractBasePtr FillV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr FillV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
auto prim_name = primitive->name();
|
auto prim_name = primitive->name();
|
||||||
|
for (auto &input : input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(input);
|
||||||
|
}
|
||||||
const int64_t input_num = 2;
|
const int64_t input_num = 2;
|
||||||
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, prim_name);
|
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
|
||||||
auto infer_type = FillV2InferType(primitive, input_args);
|
auto infer_type = FillV2InferType(primitive, input_args);
|
||||||
auto infer_shape = FillV2InferShape(primitive, input_args);
|
auto infer_shape = FillV2InferShape(primitive, input_args);
|
||||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MIND_API_OPERATOR_IMPL(FillV2, BaseOperator);
|
||||||
REGISTER_PRIMITIVE_EVAL_IMPL(FillV2, prim::kPrimFillV2, FillV2Infer, nullptr, true);
|
REGISTER_PRIMITIVE_EVAL_IMPL(FillV2, prim::kPrimFillV2, FillV2Infer, nullptr, true);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import pytest
|
||||||
|
import mindspore as ms
|
||||||
|
from mindspore import context, nn, Tensor
|
||||||
|
from mindspore.ops.operations import _inner_ops as inner
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
    """Minimal Cell that forwards its two inputs to the inner FillV2 primitive."""

    def __init__(self):
        super().__init__()
        # The primitive is created once and reused on every construct() call.
        self.op = inner.FillV2()

    def construct(self, shape, value):
        """Apply FillV2 to (shape, value) and return its output unchanged."""
        return self.op(shape, value)
|
||||||
|
|
||||||
|
|
||||||
|
def dyn_case():
    """Run Net with dynamic-shape placeholders and check the output shape.

    Both inputs are first registered through set_inputs as tensors whose single
    dimension is left unknown (None). The network is then called with a concrete
    int32 shape tensor [2, 3] and a complex64 value 1 + 2j, and only the shape of
    the result is asserted.
    """
    fill_net = Net()

    # Placeholders: dimension sizes are unknown when the inputs are registered.
    fill_net.set_inputs(Tensor(shape=[None], dtype=ms.int32),
                        Tensor(shape=[None], dtype=ms.complex64))

    target_shape = Tensor([2, 3], dtype=ms.int32)
    fill_value = Tensor(1 + 2j, dtype=ms.complex64)
    result = fill_net(target_shape, fill_value)

    assert result.asnumpy().shape == (2, 3)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
@pytest.mark.platform_x86_gpu
@pytest.mark.env_onecard
def test_fill_v2_dyn():
    """
    Feature: test FillV2 dynamic shape in gpu.
    Description: inputs are dynamic shape.
    Expectation: expect correct shape result.
    """
    # Same scenario in both execution modes, graph mode first and then PyNative.
    for run_mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=run_mode, device_target='GPU')
        dyn_case()
|
Loading…
Reference in New Issue