!44801 new kernel mod: gpu_convert_to_dynamic_shape
Merge pull request !44801 from Yanzhi_YI/gpuconverttodynamicshape
This commit is contained in:
commit
4bddb6882d
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -14,25 +14,25 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_GPU_CONVERT_TO_DYNAMIC_SHAPE_GPU_KERNEL_H
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_GPU_CONVERT_TO_DYNAMIC_SHAPE_GPU_KERNEL_H
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_OTHER_GPU_CONVERT_TO_DYNAMIC_SHAPE_GPU_KERNEL_H
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_OTHER_GPU_CONVERT_TO_DYNAMIC_SHAPE_GPU_KERNEL_H
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class GpuConvertToDynamicShapeGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
||||
class GpuConvertToDynamicShapeGpuKernelMod : public NativeGpuKernelMod {
|
||||
public:
|
||||
GpuConvertToDynamicShapeGpuKernelMod() { ResetResource(); }
|
||||
~GpuConvertToDynamicShapeGpuKernelMod() override = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
if (is_null_input_) {
|
||||
if (input_shape_.size() == 0) {
|
||||
return true;
|
||||
}
|
||||
VARIABLE_NOT_USED(workspace);
|
||||
|
@ -40,58 +40,57 @@ class GpuConvertToDynamicShapeGpuKernelMod : public DeprecatedNativeGpuKernelMod
|
|||
T *output_device_address = GetDeviceAddress<T>(outputs, 0);
|
||||
cuda_stream_ptr_ = stream_ptr;
|
||||
|
||||
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
|
||||
cudaMemcpyAsync(output_device_address, input_device_address, input_size_ * sizeof(T),
|
||||
cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"Failed to copy gpu memory.");
|
||||
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
|
||||
cudaMemcpyAsync(output_device_address, input_device_address, input_size_ * sizeof(T), cudaMemcpyDeviceToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"Failed to copy gpu memory.");
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_node_ = kernel_node;
|
||||
size_t input_count = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_count != 1) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs should be 1, but got " << input_count;
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
MS_EXCEPTION_IF_NULL(base_operator);
|
||||
kernel_name_ = base_operator->GetPrim()->name();
|
||||
constexpr size_t input_num = 1;
|
||||
if (inputs.size() != input_num) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs should be 1, but got " << inputs.size();
|
||||
}
|
||||
|
||||
input_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
is_null_input_ = CHECK_SHAPE_NULL(input_shape_, kernel_name, "input");
|
||||
if (is_null_input_) {
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
input_size_ = 1;
|
||||
for (const auto &e : input_shape_) {
|
||||
input_size_ *= e;
|
||||
}
|
||||
|
||||
InitSizeLists();
|
||||
is_need_retrieve_output_shape_ = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
void ResetResource() noexcept override {
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
int ret = KernelMod::Resize(base_operator, inputs, outputs);
|
||||
if (ret != KRET_OK && ret != KRET_UNKNOWN_OUT_SHAPE) {
|
||||
return ret;
|
||||
}
|
||||
outputs_ = outputs;
|
||||
input_shape_ = inputs[0]->GetShapeVector();
|
||||
outputs_[0]->SetShapeVector(input_shape_);
|
||||
input_size_ = 1;
|
||||
for (const auto &e : input_shape_) {
|
||||
input_size_ *= e;
|
||||
}
|
||||
InitSizeLists();
|
||||
return ret;
|
||||
}
|
||||
|
||||
void ResetResource() noexcept {
|
||||
cuda_stream_ptr_ = nullptr;
|
||||
input_shape_.clear();
|
||||
input_size_ = 1;
|
||||
is_null_input_ = false;
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
}
|
||||
|
||||
protected:
|
||||
void SyncData() override {
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(cuda_stream_ptr_)),
|
||||
"cudaStreamSynchronized failed");
|
||||
std::vector<KernelTensorPtr> GetOutputs() override { return outputs_; }
|
||||
|
||||
std::vector<TypeId> output_types = {common::AnfAlgo::GetOutputInferDataType(kernel_node_.lock(), 0)};
|
||||
std::vector<ShapeVector> output_shapes = {input_shape_};
|
||||
common::AnfAlgo::SetOutputInferTypeAndShape(output_types, output_shapes, kernel_node_.lock().get());
|
||||
}
|
||||
void InitSizeLists() override {
|
||||
protected:
|
||||
void InitSizeLists() {
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
input_size_list_.push_back(input_size_ * sizeof(T));
|
||||
output_size_list_.push_back(input_size_ * sizeof(T));
|
||||
}
|
||||
|
@ -100,9 +99,9 @@ class GpuConvertToDynamicShapeGpuKernelMod : public DeprecatedNativeGpuKernelMod
|
|||
void *cuda_stream_ptr_;
|
||||
ShapeVector input_shape_;
|
||||
int64_t input_size_;
|
||||
bool is_null_input_;
|
||||
std::vector<KernelTensorPtr> outputs_{};
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_GPU_CONVERT_TO_DYNAMIC_SHAPE_GPU_KERNEL_H
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_OTHER_GPU_CONVERT_TO_DYNAMIC_SHAPE_GPU_KERNEL_H
|
||||
|
|
|
@ -190,8 +190,6 @@ AbstractBasePtr InferImplDivNoNan(const AnalysisEnginePtr &engine_ptr, const Pri
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplLinSpace(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMapUniform(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -418,22 +418,6 @@ AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &pri
|
|||
return ret;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string &op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
|
||||
ShapeVector input_shape = input->shape()->shape();
|
||||
int32_t input_rank = SizeToInt(input_shape.size());
|
||||
ShapeVector inferred_shape(input_rank, Shape::kShapeDimAny);
|
||||
ShapeVector min_shape(input_rank, 1);
|
||||
ShapeVector max_shape = input_shape;
|
||||
|
||||
ShapePtr shape = std::make_shared<Shape>(inferred_shape, min_shape, max_shape);
|
||||
return std::make_shared<AbstractTensor>(input->element(), shape);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: Ref/Tensor, universal
|
||||
|
|
|
@ -342,8 +342,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimUpdateState, R{InferImplUpdateState, nullptr, true}},
|
||||
// Debug
|
||||
{prim::kPrimDebug, R{InferImplDebug, nullptr, true}},
|
||||
// Dynamic shape testing
|
||||
{prim::kPrimGpuConvertToDynamicShape, R{InferImplGpuConvertToDynamicShape, nullptr, true}},
|
||||
// RowTensor
|
||||
{prim::kPrimMakeRowTensor, R{InferImplMakeRowTensor, nullptr, true}},
|
||||
{prim::kPrimRowTensorGetValues, R{InferImplRowTensorGetValues, nullptr, true}},
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/gpu_convert_to_dynamic_shape.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr GpuConvertToDynamicShapeInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
if (IsDynamicRank(input_shape)) {
|
||||
return std::make_shared<abstract::Shape>(ShapeVector{abstract::Shape::kShapeRankAny});
|
||||
}
|
||||
ShapeVector output_shape_dyn;
|
||||
for (size_t i = 0; i < input_shape.size(); ++i) {
|
||||
output_shape_dyn.push_back(abstract::Shape::kShapeDimAny);
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(output_shape_dyn);
|
||||
}
|
||||
|
||||
TypePtr GpuConvertToDynamicShapeInferType(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto x_type = input_args[0]->BuildType();
|
||||
return x_type;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_OPERATOR_IMPL(GpuConvertToDynamicShape, BaseOperator);
|
||||
AbstractBasePtr GpuConvertToDynamicShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t kInputNum = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputNum, primitive->name());
|
||||
auto infer_type = GpuConvertToDynamicShapeInferType(primitive, input_args);
|
||||
auto infer_shape = GpuConvertToDynamicShapeInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(GpuConvertToDynamicShape, prim::kPrimGpuConvertToDynamicShape,
|
||||
GpuConvertToDynamicShapeInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CORE_OPS_GPU_CONVERT_TO_DYNAMIC_SHAPE_H_
|
||||
#define MINDSPORE_CORE_OPS_GPU_CONVERT_TO_DYNAMIC_SHAPE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameGpuConvertToDynamicShape = "GpuConvertToDynamicShape";
|
||||
/// \brief Gpu convert to dynamic shape.
|
||||
class MIND_API GpuConvertToDynamicShape : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(GpuConvertToDynamicShape);
|
||||
/// \brief Constructor.
|
||||
GpuConvertToDynamicShape() : BaseOperator(kNameGpuConvertToDynamicShape) { InitIOName({"input"}, {"output"}); }
|
||||
|
||||
/// \brief Init.
|
||||
void Init() const {}
|
||||
};
|
||||
|
||||
abstract::AbstractBasePtr GpuConvertToDynamicShapeInfer(const abstract::AnalysisEnginePtr &,
|
||||
const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using PrimGpuConvertToDynamicShapePtr = std::shared_ptr<GpuConvertToDynamicShape>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_GPU_CONVERT_TO_DYNAMIC_SHAPE_H_
|
Loading…
Reference in New Issue