add dynamic kernels
parent bf0142ae4b
commit b17d0a08c9

mindspore/ccsrc/backend/kernel_compiler/gpu/other/dynamic_broadcast_grad_args_gpu_kernel.cc
@@ -0,0 +1,49 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/other/dynamic_broadcast_grad_args_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(DynamicBroadcastGradientArgs,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt64)
                        .AddInputAttr(kNumberTypeInt64)
                        .AddOutputAttr(kNumberTypeInt64)
                        .AddOutputAttr(kNumberTypeInt64),
                      DynamicBroadcastGradientArgsGpuKernel, int64_t, int64_t)
MS_REG_GPU_KERNEL_TWO(DynamicBroadcastGradientArgs,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddOutputAttr(kNumberTypeInt64)
                        .AddOutputAttr(kNumberTypeInt64),
                      DynamicBroadcastGradientArgsGpuKernel, int32_t, int64_t)
MS_REG_GPU_KERNEL_TWO(DynamicBroadcastGradientArgs,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeUInt64)
                        .AddInputAttr(kNumberTypeUInt64)
                        .AddOutputAttr(kNumberTypeInt64)
                        .AddOutputAttr(kNumberTypeInt64),
                      DynamicBroadcastGradientArgsGpuKernel, uint64_t, int64_t)
MS_REG_GPU_KERNEL_TWO(DynamicBroadcastGradientArgs,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeUInt32)
                        .AddInputAttr(kNumberTypeUInt32)
                        .AddOutputAttr(kNumberTypeInt64)
                        .AddOutputAttr(kNumberTypeInt64),
                      DynamicBroadcastGradientArgsGpuKernel, uint32_t, int64_t)
}  // namespace kernel
}  // namespace mindspore
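
These registrations only bind dtype signatures to the templated kernel; the op follows the usual BroadcastGradientArgs contract: given the shapes of two broadcast operands, it returns, for each operand, the axes its gradient must be summed over. A minimal NumPy sketch of that contract (illustration only, not part of the commit; unlike the kernel, it returns an empty index list when nothing needs reducing instead of the all-axes fallback in SetOutputValue):

# example_broadcast_gradient_args.py -- illustration, not part of the commit
import numpy as np

def broadcast_gradient_args(s0, s1):
    """Reduction axes for the gradients of an op that broadcasts s0 with s1."""
    max_rank = max(len(s0), len(s1))
    # Right-align the shapes by left-padding with 1s, numpy-style.
    p0 = [1] * (max_rank - len(s0)) + list(s0)
    p1 = [1] * (max_rank - len(s1)) + list(s1)
    r0, r1 = [], []
    for axis, (d0, d1) in enumerate(zip(p0, p1)):
        if d0 == d1:
            continue
        if d0 == 1:
            r0.append(axis)  # operand 0 was broadcast along this axis
        elif d1 == 1:
            r1.append(axis)  # operand 1 was broadcast along this axis
        else:
            raise ValueError("shapes cannot be broadcast")
    return np.array(r0), np.array(r1)

# broadcast_gradient_args([2, 1, 3], [4, 3]) -> (array([1]), array([0]))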

mindspore/ccsrc/backend/kernel_compiler/gpu/other/dynamic_broadcast_grad_args_gpu_kernel.h
@@ -0,0 +1,209 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_DYNAMIC_BROADCAST_GRADIENT_ARGS_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_DYNAMIC_BROADCAST_GRADIENT_ARGS_GPU_KERNEL_H_

#include <memory>
#include <string>
#include <vector>
#include <functional>
#include <algorithm>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
constexpr size_t kInputNum = 2;
template <typename T, typename S>
class DynamicBroadcastGradientArgsGpuKernel : public GpuKernel {
 public:
  DynamicBroadcastGradientArgsGpuKernel() : r0_size_(0), r1_size_(0) { ResetResource(); }
  ~DynamicBroadcastGradientArgsGpuKernel() = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    auto cuda_stream = reinterpret_cast<cudaStream_t>(stream_ptr);
    auto s0_addr = GetDeviceAddress<T>(inputs, 0);
    auto s1_addr = GetDeviceAddress<T>(inputs, 1);
    auto r0_addr = GetDeviceAddress<S>(outputs, 0);
    auto r1_addr = GetDeviceAddress<S>(outputs, 1);
    // The reduction indices are computed on the host, so copy both shape
    // vectors back from device memory first.
    std::vector<T> x0_value(input_size_list_[0] / sizeof(T), 0);
    std::vector<T> x1_value(input_size_list_[1] / sizeof(T), 0);
    CHECK_CUDA_RET_WITH_EXCEPT(
      kernel_node_, cudaMemcpyAsync(&x0_value[0], s0_addr, input_size_list_[0], cudaMemcpyDeviceToHost, cuda_stream),
      "DynamicBroadcastGradientArgs copy s0 value failed");
    CHECK_CUDA_RET_WITH_EXCEPT(
      kernel_node_, cudaMemcpyAsync(&x1_value[0], s1_addr, input_size_list_[1], cudaMemcpyDeviceToHost, cuda_stream),
      "DynamicBroadcastGradientArgs copy s1 value failed");
    auto grad_reduce_idx = CalOut({x0_value, x1_value});
    r0_size_ = SetOutputValue(r0_addr, grad_reduce_idx[0], x0_value.size(), cuda_stream);
    r1_size_ = SetOutputValue(r1_addr, grad_reduce_idx[1], x1_value.size(), cuda_stream);

    return true;
  }
  bool Init(const CNodePtr &kernel_node) override {
    kernel_node_ = kernel_node;
    auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num != kInputNum) {
      MS_LOG(EXCEPTION) << "DynamicBroadcastGradientArgs needs " << kInputNum << " inputs, but got " << input_num;
    }
    auto s0_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
    auto s1_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1);
    auto r0_shape = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
    auto r1_shape = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 1);
    if (s0_shape.size() != 1 || s1_shape.size() != 1) {
      MS_LOG(EXCEPTION) << "Inputs must be 1-D, but got " << s0_shape.size() << "-D and " << s1_shape.size() << "-D.";
    }

    auto s0_size = std::accumulate(s0_shape.begin(), s0_shape.end(), sizeof(T), std::multiplies<size_t>());
    auto s1_size = std::accumulate(s1_shape.begin(), s1_shape.end(), sizeof(T), std::multiplies<size_t>());

    input_size_list_.push_back(s0_size);
    input_size_list_.push_back(s1_size);
    output_size_list_.push_back(r0_shape[0] * sizeof(S));
    output_size_list_.push_back(r1_shape[0] * sizeof(S));
    return true;
  }
  void ResetResource() noexcept override {
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }
  void PostExecute() override {
    std::vector<size_t> r0_shape{r0_size_};
    std::vector<size_t> r1_shape{r1_size_};
    AnfAlgo::SetOutputInferTypeAndShape({TypeId::kNumberTypeInt64, TypeId::kNumberTypeInt64}, {r0_shape, r1_shape},
                                        kernel_node_.lock().get());
    MS_LOG(DEBUG) << "Run PostExecute for DynamicBroadcastGradientArgs, real r0 shape is " << r0_shape
                  << ", r1 shape is " << r1_shape;
  }

 protected:
  void InitSizeLists() override {}

 private:
  std::vector<std::vector<T>> CalOut(const std::vector<std::vector<T>> &input_shapes) {
    std::vector<std::vector<T>> grad_reduce_idx(kInputNum);
    bool all_equal = true;
    size_t max_rank = 0;
    for (size_t i = 0; i < kInputNum; i++) {
      if (input_shapes[i] != input_shapes[0]) {
        all_equal = false;
      }
      if (input_shapes[i].size() > max_rank) {
        max_rank = input_shapes[i].size();
      }
    }
    // Identical shapes broadcast trivially; no axis needs to be reduced.
    if (all_equal) {
      return grad_reduce_idx;
    }
    // Reverse the shapes so index j counts from the trailing dimension, then
    // right-pad with 1s up to max_rank, matching numpy-style alignment.
    std::vector<std::vector<T>> reverse_shapes(kInputNum);
    for (size_t i = 0; i < kInputNum; i++) {
      reverse_shapes[i] = input_shapes[i];
      std::reverse(reverse_shapes[i].begin(), reverse_shapes[i].end());
      if (reverse_shapes[i].size() < max_rank) {
        reverse_shapes[i].resize(max_rank, 1);
      }
    }
    grad_reduce_idx = GetGradIndex(reverse_shapes, max_rank);
    return grad_reduce_idx;
  }
  std::vector<std::vector<T>> GetGradIndex(const std::vector<std::vector<T>> &reverse_shapes, const size_t max_rank) {
    std::vector<std::vector<T>> grad_reduce_index(kInputNum);
    bool pre_one[kInputNum];
    bool cur_one[kInputNum];
    for (size_t i = 0; i < kInputNum; i++) {
      pre_one[i] = false;
      cur_one[i] = false;
    }
    bool set_one = false;
    for (size_t j = 0; j < max_rank; j++) {
      int out_dim = -1;
      bool out_dim_set = false;
      bool none_one = true;
      for (size_t i = 0; i < kInputNum; i++) {
        if (reverse_shapes[i][j] == 1) {
          cur_one[i] = true;
          none_one = false;
        } else {
          cur_one[i] = false;
          if (!out_dim_set || reverse_shapes[i][j] == static_cast<T>(out_dim)) {
            out_dim = static_cast<int>(reverse_shapes[i][j]);
            out_dim_set = true;
          } else {
            MS_LOG(EXCEPTION) << "Cannot broadcast inputs[0] and inputs[1].";
          }
        }
      }
      if (!out_dim_set) {
        // Every input is 1 along this axis, so every gradient reduces over it.
        for (size_t i = 0; i < kInputNum; i++) {
          (void)grad_reduce_index[i].emplace_back(max_rank - 1 - j);
        }
        continue;
      } else if (std::equal(cur_one, cur_one + kInputNum, pre_one) && set_one) {
        for (size_t i = 0; i < kInputNum; i++) {
          if (cur_one[i] && !none_one) {
            (void)grad_reduce_index[i].emplace_back(max_rank - 1 - j);
          }
        }
      } else {
        for (size_t i = 0; i < kInputNum; i++) {
          if (cur_one[i] && !none_one) {
            (void)grad_reduce_index[i].emplace_back(max_rank - 1 - j);
          }
        }
      }
      set_one = true;
      for (size_t i = 0; i < kInputNum; i++) {
        pre_one[i] = cur_one[i];
      }
    }
    return grad_reduce_index;
  }
  size_t SetOutputValue(S *addr, const std::vector<T> &grad_reduce_idx, size_t input_num, cudaStream_t stream) {
    std::vector<S> output;
    size_t index_num = grad_reduce_idx.size();
    // The indices were generated from the trailing axis outwards, so reverse
    // them into ascending order.
    for (size_t i = 0; i < index_num; i++) {
      output.push_back(static_cast<S>(grad_reduce_idx[index_num - 1 - i]));
    }
    size_t out_size = index_num;
    if (index_num == 0) {
      out_size = input_num;
      for (size_t i = 0; i < input_num; i++) {
        output.push_back(static_cast<S>(i));
      }
    }
    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                               cudaMemcpyAsync(addr, &output[0], out_size * sizeof(S), cudaMemcpyHostToDevice, stream),
                               "DynamicBroadcastGradientArgs copy output failed");
    return out_size;
  }
  size_t r0_size_;
  size_t r1_size_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_DYNAMIC_BROADCAST_GRADIENT_ARGS_GPU_KERNEL_H_
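
A pure-Python mirror of CalOut/GetGradIndex, handy for checking the reversed-shape index math (illustration only, not part of the commit):

# mirror_cal_out.py -- illustration, not part of the commit
def cal_out(shapes):
    """Mirror of CalOut for two 1-D shape lists; returns per-operand reduce axes."""
    if shapes[0] == shapes[1]:
        return [[], []]  # all_equal: nothing to reduce
    max_rank = max(len(s) for s in shapes)
    # Reverse then right-pad with 1s, so j == 0 is the trailing axis.
    rev = [list(reversed(s)) + [1] * (max_rank - len(s)) for s in shapes]
    reduce_idx = [[], []]
    for j in range(max_rank):
        dims = [rev[0][j], rev[1][j]]
        if dims == [1, 1]:
            # No operand sets this axis; both gradients reduce over it.
            for i in range(2):
                reduce_idx[i].append(max_rank - 1 - j)
        elif 1 in dims:
            for i in range(2):
                if dims[i] == 1:
                    reduce_idx[i].append(max_rank - 1 - j)
        elif dims[0] != dims[1]:
            raise ValueError("cannot broadcast")
    return reduce_idx

# cal_out([[2, 1, 3], [4, 3]]) -> [[1], [0]]. The kernel generates the indices
# in this trailing-first order and SetOutputValue reverses them to ascending.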

mindspore/ccsrc/backend/kernel_compiler/gpu/other/dynamic_broadcastto_gpu_kernel.cc
@@ -0,0 +1,69 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/other/dynamic_broadcastto_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
  DynamicBroadcastTo,
  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
  DynamicBroadcastToGpuKernel, double, int64_t)
MS_REG_GPU_KERNEL_TWO(
  DynamicBroadcastTo,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
  DynamicBroadcastToGpuKernel, float, int64_t)
MS_REG_GPU_KERNEL_TWO(
  DynamicBroadcastTo,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
  DynamicBroadcastToGpuKernel, half, int64_t)
MS_REG_GPU_KERNEL_TWO(
  DynamicBroadcastTo,
  KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
  DynamicBroadcastToGpuKernel, int16_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
  DynamicBroadcastTo,
  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
  DynamicBroadcastToGpuKernel, int32_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
  DynamicBroadcastTo,
  KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
  DynamicBroadcastToGpuKernel, int64_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
  DynamicBroadcastTo,
  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
  DynamicBroadcastToGpuKernel, double, int32_t)
MS_REG_GPU_KERNEL_TWO(
  DynamicBroadcastTo,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
  DynamicBroadcastToGpuKernel, float, int32_t)
MS_REG_GPU_KERNEL_TWO(
  DynamicBroadcastTo,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
  DynamicBroadcastToGpuKernel, half, int32_t)
MS_REG_GPU_KERNEL_TWO(
  DynamicBroadcastTo,
  KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
  DynamicBroadcastToGpuKernel, int16_t, int32_t)
MS_REG_GPU_KERNEL_TWO(
  DynamicBroadcastTo,
  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
  DynamicBroadcastToGpuKernel, int32_t, int32_t)
MS_REG_GPU_KERNEL_TWO(
  DynamicBroadcastTo,
  KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
  DynamicBroadcastToGpuKernel, int64_t, int32_t)
}  // namespace kernel
}  // namespace mindspore

mindspore/ccsrc/backend/kernel_compiler/gpu/other/dynamic_broadcastto_gpu_kernel.h
@@ -0,0 +1,139 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_DYNAMIC_BROADCASTTO_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_DYNAMIC_BROADCASTTO_GPU_KERNEL_H_

#include <memory>
#include <string>
#include <vector>
#include <functional>
#include <algorithm>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh"

namespace mindspore {
namespace kernel {
constexpr size_t SHAPE_SIZE = 4;
constexpr size_t kIndex2 = 2;
constexpr size_t kIndex3 = 3;
template <typename T, typename S>
class DynamicBroadcastToGpuKernel : public GpuKernel {
 public:
  DynamicBroadcastToGpuKernel() : shape_size_(0), is_null_input_(false) { ResetResource(); }
  ~DynamicBroadcastToGpuKernel() = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    if (is_null_input_) {
      return true;
    }
    auto cuda_stream = reinterpret_cast<cudaStream_t>(stream_ptr);
    auto data_addr = GetDeviceAddress<T>(inputs, 0);
    auto shape_addr = GetDeviceAddress<S>(inputs, 1);
    auto output_addr = GetDeviceAddress<T>(outputs, 0);

    BroadcastTo(input_shape_[0], input_shape_[1], input_shape_[kIndex2], input_shape_[kIndex3], output_shape_[0],
                output_shape_[1], output_shape_[kIndex2], output_shape_[kIndex3], data_addr, output_addr, cuda_stream);
    // Copy the runtime shape values back to the host so PostExecute can
    // re-infer the real output shape.
    real_output_shape_ = std::vector<S>(input_size_list_[1] / sizeof(S), 0);
    CHECK_CUDA_RET_WITH_EXCEPT(
      kernel_node_,
      cudaMemcpyAsync(&real_output_shape_[0], shape_addr, input_size_list_[1], cudaMemcpyDeviceToHost, cuda_stream),
      "DynamicBroadcastTo copy real output shape value failed");
    return true;
  }
  bool Init(const CNodePtr &kernel_node) override {
    kernel_node_ = kernel_node;
    auto input_shapes = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
    auto shape_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1);
    auto output_shapes = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
    is_null_input_ = CHECK_NULL_INPUT(input_shapes) || CHECK_NULL_INPUT(output_shapes) || CHECK_NULL_INPUT(shape_shape);
    if (is_null_input_) {
      MS_LOG(WARNING) << "For 'DynamicBroadcastToGpuKernel', input or output is null";
      InitSizeLists();
      return true;
    }

    if (input_shapes.size() > SHAPE_SIZE || output_shapes.size() > SHAPE_SIZE) {
      MS_LOG(EXCEPTION) << "BroadcastTo operation does not support dim greater than " << SHAPE_SIZE;
    }

    if (output_shapes.size() < input_shapes.size()) {
      MS_LOG(EXCEPTION) << "The rank of BroadcastTo's output [" << output_shapes.size()
                        << "] cannot be smaller than the rank of the input [" << input_shapes.size() << "].";
    }

    shape_size_ = std::accumulate(shape_shape.begin(), shape_shape.end(), sizeof(S), std::multiplies<size_t>());

    // Right-align the input shape against the output rank; leading axes keep
    // their default of 1.
    size_t offset = output_shapes.size() - input_shapes.size();
    for (size_t i = 0; i < input_shapes.size(); i++) {
      input_shape_[i + offset] = input_shapes[i];
    }

    for (size_t j = 0; j < output_shapes.size(); j++) {
      output_shape_[j] = (output_shapes[j] > 0 ? output_shapes[j] : input_shapes[j]);
    }

    InitSizeLists();
    return true;
  }
  void ResetResource() noexcept override {
    real_output_shape_.clear();
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
    for (size_t i = 0; i < SHAPE_SIZE; i++) {
      input_shape_[i] = 1;
      output_shape_[i] = 1;
    }
  }
  void PostExecute() override {
    auto data_type = AnfAlgo::GetInputDeviceDataType(kernel_node_.lock(), 0);
    std::vector<size_t> output_shape;
    std::transform(real_output_shape_.begin(), real_output_shape_.end(), std::back_inserter(output_shape),
                   [](const S &i) { return static_cast<size_t>(i); });
    AnfAlgo::SetOutputInferTypeAndShape({data_type}, {output_shape}, kernel_node_.lock().get());
    MS_LOG(DEBUG) << "Run PostExecute for DynamicBroadcastTo, real output shape is " << output_shape;
  }

 protected:
  void InitSizeLists() override {
    input_size_list_.push_back(input_shape_[0] * input_shape_[1] * input_shape_[kIndex2] * input_shape_[kIndex3] *
                               sizeof(T));
    input_size_list_.push_back(shape_size_);
    output_size_list_.push_back(output_shape_[0] * output_shape_[1] * output_shape_[kIndex2] * output_shape_[kIndex3] *
                                sizeof(T));
  }

 private:
  size_t shape_size_;
  size_t input_shape_[SHAPE_SIZE] = {1, 1, 1, 1};
  size_t output_shape_[SHAPE_SIZE] = {1, 1, 1, 1};
  bool is_null_input_ = false;
  std::vector<S> real_output_shape_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_DYNAMIC_BROADCASTTO_GPU_KERNEL_H_
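
Both shapes are reasoned over as fixed 4-D arrays padded with 1s; Init right-aligns the input shape against the output rank, and any non-positive output dim falls back to the corresponding input dim. A Python sketch of that alignment (illustration only, not part of the commit):

# align_shapes.py -- illustration, not part of the commit
SHAPE_SIZE = 4

def align_shapes(input_shape, output_shape):
    in4 = [1] * SHAPE_SIZE
    out4 = [1] * SHAPE_SIZE
    offset = len(output_shape) - len(input_shape)
    for i, d in enumerate(input_shape):
        in4[i + offset] = d  # right-align within the output rank
    for j, d in enumerate(output_shape):
        out4[j] = d if d > 0 else input_shape[j]  # non-positive dims fall back
    return in4, out4

# align_shapes([3], [2, 3]) -> ([1, 3, 1, 1], [2, 3, 1, 1]); the CUDA
# BroadcastTo is then launched on these two 4-D shapes.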

mindspore/ccsrc/backend/kernel_compiler/gpu/other/dynamic_reshape_gpu_kernel.cc
@@ -0,0 +1,58 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/other/dynamic_reshape_gpu_kernel.h"
#include <iterator>
#include <algorithm>
#include <functional>
#include "backend/kernel_compiler/common_utils.h"
#include "runtime/device/gpu/gpu_common.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
  DynamicReshape,
  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
  DynamicReshapeKernel, double, int)
MS_REG_GPU_KERNEL_TWO(
  DynamicReshape,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
  DynamicReshapeKernel, float, int)
MS_REG_GPU_KERNEL_TWO(
  DynamicReshape,
  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
  DynamicReshapeKernel, int, int)
MS_REG_GPU_KERNEL_TWO(
  DynamicReshape,
  KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
  DynamicReshapeKernel, int64_t, int)
MS_REG_GPU_KERNEL_TWO(
  DynamicReshape,
  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
  DynamicReshapeKernel, double, int64_t)
MS_REG_GPU_KERNEL_TWO(
  DynamicReshape,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
  DynamicReshapeKernel, float, int64_t)
MS_REG_GPU_KERNEL_TWO(
  DynamicReshape,
  KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
  DynamicReshapeKernel, int64_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
  DynamicReshape,
  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
  DynamicReshapeKernel, int, int64_t)
}  // namespace kernel
}  // namespace mindspore

mindspore/ccsrc/backend/kernel_compiler/gpu/other/dynamic_reshape_gpu_kernel.h
@@ -0,0 +1,106 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_DYNAMIC_RESHAPE_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_DYNAMIC_RESHAPE_GPU_KERNEL_H_

#include <memory>
#include <string>
#include <vector>
#include <functional>
#include <algorithm>
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
template <typename T, typename S>
class DynamicReshapeKernel : public GpuKernel {
 public:
  DynamicReshapeKernel() : data_type_size_(0), shape_size_(0) { ResetResource(); }
  ~DynamicReshapeKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    auto cuda_stream = reinterpret_cast<cudaStream_t>(stream_ptr);
    auto data_addr = GetDeviceAddress<unsigned char>(inputs, 0);
    auto shape_addr = GetDeviceAddress<S>(inputs, 1);
    auto output_addr = GetDeviceAddress<unsigned char>(outputs, 0);

    // A reshape never moves elements, so the data is forwarded byte for byte.
    CHECK_CUDA_RET_WITH_EXCEPT(
      kernel_node_, cudaMemcpyAsync(output_addr, data_addr, input_size_list_[0], cudaMemcpyDeviceToDevice, cuda_stream),
      "DynamicReshape copy data failed");
    real_output_shape_ = std::vector<S>(input_size_list_[1] / sizeof(S), 0);
    CHECK_CUDA_RET_WITH_EXCEPT(
      kernel_node_,
      cudaMemcpyAsync(&real_output_shape_[0], shape_addr, input_size_list_[1], cudaMemcpyDeviceToHost, cuda_stream),
      "DynamicReshape copy real output shape value failed");
    return true;
  }
  bool Init(const CNodePtr &kernel_node) override {
    kernel_node_ = kernel_node;
    auto output_shape = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
    auto input_x_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
    auto input_shape_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1);
    auto data_type = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
    data_type_size_ = mindspore::kernel::GetDtypeNbyte(TypeIdToString(data_type, true));
    shape_size_ = input_shape_shape.size();
    size_t input_x_size =
      std::accumulate(input_x_shape.begin(), input_x_shape.end(), data_type_size_, std::multiplies<size_t>());
    input_size_list_.push_back(input_x_size);
    size_t input_shape_size =
      std::accumulate(input_shape_shape.begin(), input_shape_shape.end(), sizeof(S), std::multiplies<size_t>());
    input_size_list_.push_back(input_shape_size);
    size_t output_size =
      std::accumulate(output_shape.begin(), output_shape.end(), data_type_size_, std::multiplies<size_t>());
    output_size_list_.push_back(output_size);

    return true;
  }
  void ResetResource() noexcept override {
    real_output_shape_.clear();
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }
  void PostExecute() override {
    auto data_type = AnfAlgo::GetInputDeviceDataType(kernel_node_.lock(), 0);
    std::vector<size_t> output_shape;
    std::transform(real_output_shape_.begin(), real_output_shape_.end(), std::back_inserter(output_shape),
                   [](const S &value) { return static_cast<size_t>(value); });
    AnfAlgo::SetOutputInferTypeAndShape({data_type}, {output_shape}, kernel_node_.lock().get());
    MS_LOG(DEBUG) << "Run PostExecute for DynamicReshape, real output shape is " << output_shape;
  }

 protected:
  void InitSizeLists() override {}

 private:
  size_t data_type_size_;
  size_t shape_size_;
  std::vector<S> real_output_shape_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_DYNAMIC_RESHAPE_GPU_KERNEL_H_
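
The kernel makes the reshape explicit as a flat byte copy plus a host read of the target shape, which PostExecute then stamps onto the output; numerically the op is nothing more than this NumPy statement (illustration only, not part of the commit):

# reshape_equivalence.py -- illustration, not part of the commit
import numpy as np

data = np.arange(1, 9, dtype=np.float32).reshape(2, 4)
shape = np.array([4, 2], dtype=np.int64)
# Same bytes, new shape metadata -- exactly what the device-to-device
# memcpy plus PostExecute achieves.
out = data.reshape(tuple(int(s) for s in shape))
assert out.shape == (4, 2)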

@@ -0,0 +1,53 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.operations as ops
from mindspore import Tensor
from mindspore.ops.operations import _inner_ops as inner

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.d_shape = ops.DynamicShape()
        self.d_broadcastto = inner.DynamicBroadcastTo()

    def construct(self, data, shape):
        # Route the target shape through DynamicShape so it stays a runtime value.
        shape = self.d_shape(shape)
        return self.d_broadcastto(data, shape)


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_net_float32():
    """
    Feature: Dynamic BroadcastTo.
    Description: test cases for dynamic_broadcastto.
    Expectation: the result matches the expected array.
    """
    data = Tensor(np.array([1, 2, 3]), mindspore.float32)
    shape = Tensor(np.zeros((2, 3)), mindspore.int64)
    expect_data = np.array([[1, 2, 3], [1, 2, 3]]).astype(np.float32)
    net = Net()
    output = net(data, shape)
    print(output.asnumpy())
    assert np.array_equal(output.asnumpy(), expect_data)
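
The shape argument is deliberately a zeros tensor fed through DynamicShape: its values are irrelevant, only its shape is, and going through DynamicShape keeps the target a runtime value rather than a graph-time constant. The equivalent NumPy computation (illustration only, not part of the commit):

# numpy_equivalent.py -- illustration, not part of the commit
import numpy as np

placeholder = np.zeros((2, 3))
target_shape = np.array(placeholder.shape)  # what DynamicShape yields: [2, 3]
out = np.broadcast_to(np.array([1, 2, 3], dtype=np.float32), tuple(target_shape))
# out == [[1, 2, 3], [1, 2, 3]]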

@@ -0,0 +1,52 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _inner_ops as ops

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.d_reshape = ops.DynamicReshape()

    def construct(self, data, shape):
        return self.d_reshape(data, shape)


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_net_float32():
    """
    Feature: Dynamic Reshape.
    Description: test cases for dynamic reshape.
    Expectation: the result matches the expected array.
    """
    data = Tensor(np.arange(1, 9).reshape((2, 4)), mindspore.float32)
    shape = Tensor(np.array([4, 2]), mindspore.int64)
    expect_data = np.arange(1, 9).reshape((4, 2))
    print(data)
    print(shape)
    net = Net()
    output = net(data, shape)
    print(output.asnumpy())
    assert np.array_equal(output.asnumpy(), expect_data)