optimize infer

Summary: move OpRunnerInfo from kernel::pyboost to the runtime namespace, drop the trailing OpRunnerInfo* parameter from the pyboost Call/Customize entry points, fold the OpRunnerInfo-based InferOutput/InferOpOutput overloads into the single variadic infer path, and let the generated grad wrappers copy output_abs back after Call.

parent 56ddd6c010
commit c0421eec03
@@ -859,7 +859,7 @@ void MindRTBackend::RunGraphBySingleOp(const GraphCompilerInfo &graph_compiler_i
       MS_LOG(DEBUG) << "Run " << primitive->name() << " by pyboost";
       graph_compiler_->GetSingleOpInputTensors(kernel, op_output_map, parameter_index, inputs[graph_index], true,
           &input_info);
-      kernel::pyboost::OpRunnerInfo op_runner_info{
+      runtime::OpRunnerInfo op_runner_info{
         primitive, device_target, input_info.input_values, input_info.input_abs, {}, kernel->abstract()};
       runtime::PyBoostOpExecute::GetInstance().RunPyBoostCall(&op_runner_info, &op_outputs);
     } else {
@@ -37,15 +37,6 @@ namespace pyboost {
 using GradFunc = std::function<void()>;
 constexpr size_t kAbstractCacheSize = 8192;
 
-struct OpRunnerInfo {
-  const PrimitivePtr &prim;
-  const std::string &device_target;
-  const vector<ValuePtr> &inputs;
-  const abstract::AbstractBasePtrList &inputs_abs;
-  const std::vector<InputType> &inputs_mask;
-  abstract::AbstractBasePtr output_abs;
-};
-
 // OpRunner is a base class for operators.
 // OpRunner records the operator's input abstract,
 // output abstract and output Tensors for grad,
@@ -124,40 +115,18 @@ class BACKEND_EXPORT OpRunner : public std::enable_shared_from_this<OpRunner> {
   }
 
   template <typename... T>
-  void GenerateAbstract(T &...args) {
+  void GenerateAbstract(T &... args) {
     (input_abs_.emplace_back(ConvertAbstract(args)), ...);
   }
 
   // Member function for Infer and creating output tensors.
   template <typename... T>
-  void InferOutput(T &...args) {
+  void InferOutput(T &... args) {
     runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyBoostInferOutput,
         primitive_->name(), false);
     (input_abs_.emplace_back(ConvertAbstract(args)), ...);
     output_abs_ = PyBoostUtils::InferByOpDef(primitive_, input_abs_);
     MS_EXCEPTION_IF_NULL(output_abs_);
-    CreateOutput();
-  }
-
-  void InferOutput(OpRunnerInfo *op_runner_info) {
-    MS_EXCEPTION_IF_NULL(op_runner_info);
-    runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyBoostInferOutput,
-        primitive_->name(), false);
-    if (op_runner_info->inputs_abs.empty()) {
-      MS_LOG(EXCEPTION) << "Get empty input abstract";
-    }
-    input_abs_ = op_runner_info->inputs_abs;
-    if (op_runner_info->output_abs == nullptr) {
-      output_abs_ = PyBoostUtils::InferByOpDef(primitive_, input_abs_);
-      MS_EXCEPTION_IF_NULL(output_abs_);
-      op_runner_info->output_abs = output_abs_;
-    } else {
-      output_abs_ = op_runner_info->output_abs;
-    }
-    CreateOutput();
-  }
-
-  void CreateOutput() {
     MS_LOG(DEBUG) << "PyBoost infer output " << output_abs_->ToString();
     PyBoostUtils::CreateOutputTensor(output_abs_, &outputs_);
     abstract_cache_.Push(output_abs_);
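Note: with the OpRunnerInfo overload gone, every pyboost operator now goes through the single variadic InferOutput/InferOpOutput path. Below is a minimal sketch of the call order a customize implementation follows after this change, assembled only from calls that appear elsewhere in this commit; the function name and tensor arguments are illustrative assumptions, not repository code.

// Sketch only: mirrors the post-change infer flow used by the customize functions below.
tensor::TensorPtr RunBinaryOpSketch(const std::shared_ptr<OpRunner> &op,
    const TensorPtr &x_tensor, const TensorPtr &y_tensor) {
  // Collect input abstracts, run InferByOpDef and create the output tensors in one step.
  OpRunner::InferOpOutput(op, x_tensor, y_tensor);
  // Allocate device memory for the inputs and the freshly created outputs.
  PyBoostUtils::PrepareOpInputs(op->device_context(), op->stream_id(), x_tensor, y_tensor);
  PyBoostUtils::PrepareOpOutputs(op->device_context(), op->stream_id(), op->outputs());
  // The real launch code is op specific; returning the first output is a placeholder.
  return op->outputs().front();
}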
@@ -165,7 +134,7 @@ class BACKEND_EXPORT OpRunner : public std::enable_shared_from_this<OpRunner> {
 
   // A static function used for the "customize" operator to generate the operator's output Tensor.
   template <typename... T>
-  static void InferOpOutput(const std::shared_ptr<OpRunner> &op, T &...args) {
+  static void InferOpOutput(const std::shared_ptr<OpRunner> &op, T &... args) {
     runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyBoostInferOutput,
         op->primitive()->name(), false);
     (op->input_abs_.emplace_back(ConvertAbstract(args)), ...);
@@ -174,26 +143,6 @@ class BACKEND_EXPORT OpRunner : public std::enable_shared_from_this<OpRunner> {
     abstract_cache_.Push(op->output_abs_);
   }
 
-  // A static function used for the "customize" operator to generate the operator's output Tensor for grad op.
-  static void InferOpOutput(const std::shared_ptr<OpRunner> &op, OpRunnerInfo *op_runner_info) {
-    MS_EXCEPTION_IF_NULL(op_runner_info);
-    runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyBoostInferOutput,
-        op->primitive()->name(), false);
-    if (op_runner_info->inputs_abs.empty()) {
-      MS_LOG(EXCEPTION) << "Get empty input abstract";
-    }
-    op->input_abs_ = op_runner_info->inputs_abs;
-    if (op_runner_info->output_abs == nullptr) {
-      op->output_abs_ = PyBoostUtils::InferByOpDef(op->primitive(), op->input_abs_);
-      MS_EXCEPTION_IF_NULL(op->output_abs_);
-      op_runner_info->output_abs = op->output_abs_;
-    } else {
-      op->output_abs_ = op_runner_info->output_abs;
-    }
-    PyBoostUtils::CreateOutputTensor(op->output_abs_, &op->outputs_);
-    abstract_cache_.Push(op->output_abs_);
-  }
-
  protected:
   // Op primitive, may delete latter.
   PrimitivePtr primitive_{nullptr};
@@ -29,7 +29,7 @@ class BACKEND_EXPORT ${op_name} : public pyboost::OpRunner {
       : OpRunner(std::move(primitive), device_context) {}
   ~${op_name}() override = default;
 
-  virtual ${return_type} Call(${call_args}, OpRunnerInfo *op_run_info = nullptr) = 0;
+  virtual ${return_type} Call(${call_args}) = 0;
 
  protected:
   static const std::string &op_name() {return op_name_;}
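For orientation, filling the template above in by hand for a hypothetical binary op gives roughly the declaration below; the class name, argument types and constructor parameters are assumptions, shown only to make clear where the removed OpRunnerInfo default argument used to sit.

// Hand-expanded sketch of the generated base class for an assumed "Add" op.
class BACKEND_EXPORT Add : public pyboost::OpRunner {
 public:
  Add(PrimitivePtr primitive, const DeviceContext *device_context)
      : OpRunner(std::move(primitive), device_context) {}
  ~Add() override = default;

  // Before this commit: Call(x_tensor, y_tensor, OpRunnerInfo *op_run_info = nullptr) = 0;
  virtual tensor::TensorPtr Call(const TensorPtr &x_tensor, const TensorPtr &y_tensor) = 0;
};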
@@ -132,7 +132,7 @@ NodePtr FuncBuilder::EmitOp(const PrimitivePtr &prim, const NodePtrList &inputs)
      << PyNativeAlgo::Common::PrintDebugInfo(op_inputs);
   MS_LOG(DEBUG) << "Get input abs size " << input_abs.size() << ", " << PyNativeAlgo::Common::PrintDebugInfo(input_abs);
   VectorRef outputs;
-  kernel::pyboost::OpRunnerInfo op_runner_info{prim, device_target_, op_inputs, input_abs, input_mask, nullptr};
+  runtime::OpRunnerInfo op_runner_info{prim, device_target_, op_inputs, input_abs, input_mask, nullptr};
   runtime::PyBoostOpExecute::GetInstance().Execute(&op_runner_info, &outputs);
   auto real_outputs = common::AnfAlgo::TransformVectorRefToMultiValue(outputs);
   MS_LOG(DEBUG) << "Get output value size " << real_outputs.size() << ", "
@@ -19,7 +19,6 @@
 #include <memory>
 #include <string>
 #include <vector>
-#include "kernel/pyboost/auto_generate/add.h"
 #include "include/common/utils/primitive_utils.h"
 #include "pipeline/pynative/pynative_utils.h"
 #include "ops/framework_ops.h"
@@ -515,14 +515,14 @@ bool GeKernelExecutor::ExecuteKernelTask(const runtime::KernelTaskType &task_typ
     if (input_addr_list.size() != kCopyTaskInputsNum) {
       MS_LOG(EXCEPTION) << "input_addr_list.size() is invalid, input_addr_list.size():" << input_addr_list.size();
     }
-    kernel::pyboost::CustomizeCopyAscend(device_context_, input_addr_list[1], input_addr_list[0], stream_id, nullptr);
+    kernel::pyboost::CustomizeCopyAscend(device_context_, input_addr_list[1], input_addr_list[0], stream_id);
   } else {
     // For contiguous task, there must be at least one input and one output.
     if (input_addr_list.empty() || output_addr_list.empty()) {
       MS_LOG(EXCEPTION) << "input_addr_list.size() or output_addr_list.size() is invalid, input_addr_list.size():"
           << input_addr_list.size() << ", output_addr_list.size():" << output_addr_list.size();
     }
-    kernel::pyboost::CustomizeCopyAscend(device_context_, input_addr_list[0], output_addr_list[0], stream_id, nullptr);
+    kernel::pyboost::CustomizeCopyAscend(device_context_, input_addr_list[0], output_addr_list[0], stream_id);
   }
 
   return true;
@@ -25,12 +25,7 @@ namespace mindspore {
 namespace kernel {
 namespace pyboost {
 tensor::TensorPtr AddAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &x_tensor,
-    const TensorPtr &y_tensor, OpRunnerInfo *op_runner_info) {
-  if (op_runner_info != nullptr) {
-    OpRunner::InferOpOutput(op, op_runner_info);
-  } else {
-    OpRunner::InferOpOutput(op, x_tensor, y_tensor);
-  }
+    const TensorPtr &y_tensor) {
   OpRunner::InferOpOutput(op, x_tensor, y_tensor);
   // No need to convert input
   PyBoostUtils::PrepareOpInputs(op->device_context(), op->stream_id(), x_tensor, y_tensor);

@@ -27,7 +27,7 @@ namespace mindspore {
 namespace kernel {
 namespace pyboost {
 tensor::TensorPtr AddAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &x_tensor,
-    const TensorPtr &y_tensor, OpRunnerInfo *op_runner_info);
+    const TensorPtr &y_tensor);
 } // namespace pyboost
 } // namespace kernel
 } // namespace mindspore
@@ -30,14 +30,8 @@ namespace pyboost {
 std::tuple<tensor::TensorPtr, tensor::TensorPtr> ArgMaxWithValueAscendCustomize(const std::shared_ptr<OpRunner> &op,
     const TensorPtr &input_tensor,
     const Int64ImmPtr &axis,
-    const BoolImmPtr &keep_dims,
-    OpRunnerInfo *op_runner_info) {
-  if (op_runner_info != nullptr) {
-    OpRunner::InferOpOutput(op, op_runner_info);
-  } else {
-    OpRunner::InferOpOutput(op, input_tensor, axis, keep_dims);
-  }
-
+    const BoolImmPtr &keep_dims) {
+  OpRunner::InferOpOutput(op, input_tensor, axis, keep_dims);
   // Convert ValuePtr to c++ scalar
   auto axis_imm = GetValue<int64_t>(axis);
   auto keep_dims_imm = GetValue<bool>(keep_dims);

@@ -31,8 +31,7 @@ namespace pyboost {
 std::tuple<tensor::TensorPtr, tensor::TensorPtr> ArgMaxWithValueAscendCustomize(const std::shared_ptr<OpRunner> &op,
     const TensorPtr &input_tensor,
     const Int64ImmPtr &axis,
-    const BoolImmPtr &keep_dims,
-    OpRunnerInfo *op_runner_info);
+    const BoolImmPtr &keep_dims);
 } // namespace pyboost
 } // namespace kernel
 } // namespace mindspore
@@ -30,14 +30,8 @@ namespace pyboost {
 std::tuple<tensor::TensorPtr, tensor::TensorPtr> ArgMinWithValueAscendCustomize(const std::shared_ptr<OpRunner> &op,
     const TensorPtr &input_tensor,
     const Int64ImmPtr &axis,
-    const BoolImmPtr &keep_dims,
-    OpRunnerInfo *op_runner_info) {
-  if (op_runner_info != nullptr) {
-    OpRunner::InferOpOutput(op, op_runner_info);
-  } else {
-    OpRunner::InferOpOutput(op, input_tensor, axis, keep_dims);
-  }
-
+    const BoolImmPtr &keep_dims) {
+  OpRunner::InferOpOutput(op, input_tensor, axis, keep_dims);
   // Convert ValuePtr to c++ scalar
   auto axis_imm = GetValue<int64_t>(axis);
   auto keep_dims_imm = GetValue<bool>(keep_dims);

@@ -31,8 +31,7 @@ namespace pyboost {
 std::tuple<tensor::TensorPtr, tensor::TensorPtr> ArgMinWithValueAscendCustomize(const std::shared_ptr<OpRunner> &op,
     const TensorPtr &input_tensor,
     const Int64ImmPtr &axis,
-    const BoolImmPtr &keep_dims,
-    OpRunnerInfo *op_runner_info);
+    const BoolImmPtr &keep_dims);
 } // namespace pyboost
 } // namespace kernel
 } // namespace mindspore
@@ -23,8 +23,7 @@
 namespace mindspore {
 namespace kernel {
 namespace pyboost {
-tensor::TensorPtr ContiguousAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &input_tensor,
-    OpRunnerInfo *op_runner_info) {
+tensor::TensorPtr ContiguousAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &input_tensor) {
   MS_LOG(DEBUG) << "Call start";
   MS_EXCEPTION_IF_NULL(input_tensor);
 

@@ -27,8 +27,7 @@
 namespace mindspore {
 namespace kernel {
 namespace pyboost {
-tensor::TensorPtr ContiguousAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &input_tensor,
-    OpRunnerInfo *op_runner_info);
+tensor::TensorPtr ContiguousAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &input_tensor);
 } // namespace pyboost
 } // namespace kernel
 } // namespace mindspore
@@ -64,14 +64,8 @@ tensor::TensorPtr Conv2DAscendCall(const std::shared_ptr<OpRunner> &op, const de
 tensor::TensorPtr Conv2DAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &input_tensor,
     const TensorPtr &weight_tensor, const std::optional<TensorPtr> &bias_tensor,
     const ValueTuplePtr &stride, const ValueTuplePtr &padding,
-    const ValueTuplePtr &dilation, const Int64ImmPtr &groups,
-    OpRunnerInfo *op_runner_info) {
-  if (op_runner_info != nullptr) {
-    OpRunner::InferOpOutput(op, op_runner_info);
-  } else {
-    OpRunner::InferOpOutput(op, input_tensor, weight_tensor, bias_tensor, stride, padding, dilation, groups);
-  }
-
+    const ValueTuplePtr &dilation, const Int64ImmPtr &groups) {
+  OpRunner::InferOpOutput(op, input_tensor, weight_tensor, bias_tensor, stride, padding, dilation, groups);
   // Convert ValueTuple to std::vector
   std::vector<int64_t> stride_vector = ConvertValueTupleToVector<int64_t>(stride);
   std::vector<int64_t> padding_vector = ConvertValueTupleToVector<int64_t>(padding);

@@ -30,8 +30,7 @@ namespace pyboost {
 tensor::TensorPtr Conv2DAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &input_tensor,
     const TensorPtr &weight_tensor, const std::optional<TensorPtr> &bias_tensor,
     const ValueTuplePtr &stride, const ValueTuplePtr &padding,
-    const ValueTuplePtr &dilation, const Int64ImmPtr &groups,
-    OpRunnerInfo *op_runner_info);
+    const ValueTuplePtr &dilation, const Int64ImmPtr &groups);
 } // namespace pyboost
 } // namespace kernel
 } // namespace mindspore
@@ -21,8 +21,7 @@
 namespace mindspore {
 namespace kernel {
 namespace pyboost {
-tensor::TensorPtr CopyAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &input_tensor,
-    OpRunnerInfo *op_runner_info) {
+tensor::TensorPtr CopyAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &input_tensor) {
   MS_LOG(DEBUG) << "Call start";
   auto input_abs = input_tensor->ToAbstract();
   input_abs->set_value(kValueAny);

@@ -27,8 +27,7 @@
 namespace mindspore {
 namespace kernel {
 namespace pyboost {
-tensor::TensorPtr CopyAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &input_tensor,
-    OpRunnerInfo *op_runner_info);
+tensor::TensorPtr CopyAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &input_tensor);
 } // namespace pyboost
 } // namespace kernel
 } // namespace mindspore
@@ -23,8 +23,7 @@ namespace kernel {
 namespace pyboost {
 // Unconventional pyboost writing. Please do not refer to this to implement other operators!
 void CustomizeCopyAscend(device::DeviceContext *device_context, const device::DeviceAddressPtr &input_addr,
-    const device::DeviceAddressPtr &output_addr, const size_t &stream_id,
-    OpRunnerInfo *op_runner_info) {
+    const device::DeviceAddressPtr &output_addr, const size_t &stream_id) {
   MS_LOG(DEBUG) << "Call start";
   MS_EXCEPTION_IF_NULL(input_addr);
   MS_EXCEPTION_IF_NULL(output_addr);

@@ -28,8 +28,7 @@ namespace mindspore {
 namespace kernel {
 namespace pyboost {
 void CustomizeCopyAscend(device::DeviceContext *device_context, const device::DeviceAddressPtr &input_addr,
-    const device::DeviceAddressPtr &output_addr, const size_t &stream_id,
-    OpRunnerInfo *op_runner_info);
+    const device::DeviceAddressPtr &output_addr, const size_t &stream_id);
 } // namespace pyboost
 } // namespace kernel
 } // namespace mindspore
@@ -25,19 +25,14 @@ namespace mindspore {
 namespace kernel {
 namespace pyboost {
 tensor::TensorPtr GatherDGradAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &x,
-    const Int64ImmPtr dim, const TensorPtr &index, const TensorPtr &d_out,
-    OpRunnerInfo *op_runner_info) {
+    const Int64ImmPtr dim, const TensorPtr &index, const TensorPtr &d_out) {
   MS_EXCEPTION_IF_NULL(dim);
   MS_EXCEPTION_IF_NULL(op);
   MS_EXCEPTION_IF_NULL(x);
   MS_EXCEPTION_IF_NULL(index);
   MS_EXCEPTION_IF_NULL(d_out);
-  if (op_runner_info != nullptr) {
-    OpRunner::InferOpOutput(op, op_runner_info);
-  } else {
-    OpRunner::InferOpOutput(op, x, dim, index, d_out);
-  }
 
+  OpRunner::InferOpOutput(op, x, dim, index, d_out);
   auto dim_value = dim->value();
   PyBoostUtils::PrepareOpInputs(op->device_context(), op->stream_id(), d_out);
   PyBoostUtils::PrepareOpOutputs(op->device_context(), op->stream_id(), op->outputs());

@@ -26,8 +26,7 @@ namespace mindspore {
 namespace kernel {
 namespace pyboost {
 tensor::TensorPtr GatherDGradAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &x,
-    const Int64ImmPtr dim, const TensorPtr &index, const TensorPtr &d_out,
-    OpRunnerInfo *op_runner_info);
+    const Int64ImmPtr dim, const TensorPtr &index, const TensorPtr &d_out);
 } // namespace pyboost
 } // namespace kernel
 } // namespace mindspore
@@ -24,14 +24,8 @@ namespace mindspore {
 namespace kernel {
 namespace pyboost {
 tensor::TensorPtr GeLUGradAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &dy_tensor,
-    const TensorPtr &x_tensor, const TensorPtr &y_tensor,
-    OpRunnerInfo *op_runner_info) {
-  if (op_runner_info != nullptr) {
-    OpRunner::InferOpOutput(op, op_runner_info);
-  } else {
-    OpRunner::InferOpOutput(op, dy_tensor, x_tensor, y_tensor);
-  }
-
+    const TensorPtr &x_tensor, const TensorPtr &y_tensor) {
+  OpRunner::InferOpOutput(op, dy_tensor, x_tensor, y_tensor);
   // Create device address for input/output tensors
   PyBoostUtils::PrepareOpInputs(op->device_context(), op->stream_id(), dy_tensor, x_tensor, y_tensor);
   PyBoostUtils::PrepareOpOutputs(op->device_context(), op->stream_id(), op->outputs());

@@ -28,8 +28,7 @@ namespace mindspore {
 namespace kernel {
 namespace pyboost {
 tensor::TensorPtr GeLUGradAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &dy_tensor,
-    const TensorPtr &x_tensor, const TensorPtr &y_tensor,
-    OpRunnerInfo *op_runner_info);
+    const TensorPtr &x_tensor, const TensorPtr &y_tensor);
 } // namespace pyboost
 } // namespace kernel
 } // namespace mindspore
@@ -140,13 +140,8 @@ void IdentityCustomizeCall(const std::shared_ptr<OpRunner> &op, const TensorPtr
   }));
 }
 
-tensor::TensorPtr IdentityAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &x_tensor,
-    OpRunnerInfo *op_runner_info) {
-  if (op_runner_info != nullptr) {
-    OpRunner::InferOpOutput(op, op_runner_info);
-  } else {
-    OpRunner::InferOpOutput(op, x_tensor);
-  }
+tensor::TensorPtr IdentityAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &x_tensor) {
+  OpRunner::InferOpOutput(op, x_tensor);
   PyBoostUtils::PrepareOpInputs(op->device_context(), op->stream_id(), x_tensor);
   PyBoostUtils::PrepareOpOutputs(op->device_context(), op->stream_id(), op->outputs());
   FillHostInfoForAclOp(x_tensor);

@@ -26,8 +26,7 @@
 namespace mindspore {
 namespace kernel {
 namespace pyboost {
-tensor::TensorPtr IdentityAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &x_tensor,
-    OpRunnerInfo *op_runner_info);
+tensor::TensorPtr IdentityAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &x_tensor);
 } // namespace pyboost
 } // namespace kernel
 } // namespace mindspore
@@ -35,13 +35,8 @@ tensor::TensorPtr MaskedFillAscendCall(const std::shared_ptr<OpRunner> &op, cons
 } // namespace
 
 tensor::TensorPtr MaskedFillAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &input_tensor,
-    const TensorPtr &mask_tensor, const TensorPtr &value_tensor,
-    OpRunnerInfo *op_runner_info) {
-  if (op_runner_info != nullptr) {
-    OpRunner::InferOpOutput(op, op_runner_info);
-  } else {
-    OpRunner::InferOpOutput(op, input_tensor, mask_tensor, value_tensor);
-  }
+    const TensorPtr &mask_tensor, const TensorPtr &value_tensor) {
+  OpRunner::InferOpOutput(op, input_tensor, mask_tensor, value_tensor);
   PyBoostUtils::PrepareOpInputs(op->device_context(), op->stream_id(), input_tensor, mask_tensor, value_tensor);
   PyBoostUtils::PrepareOpOutputs(op->device_context(), op->stream_id(), op->outputs());
   // Async

@@ -28,8 +28,7 @@ namespace mindspore {
 namespace kernel {
 namespace pyboost {
 tensor::TensorPtr MaskedFillAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &input_tensor,
-    const TensorPtr &mask_tensor, const TensorPtr &value_tensor,
-    OpRunnerInfo *op_runner_info);
+    const TensorPtr &mask_tensor, const TensorPtr &value_tensor);
 } // namespace pyboost
 } // namespace kernel
 } // namespace mindspore
@@ -26,7 +26,7 @@ namespace mindspore {
 namespace kernel {
 namespace pyboost {
 tensor::TensorPtr ReshapeAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &input_tensor,
-    const ValueTuplePtr &shape, OpRunnerInfo *op_runner_info) {
+    const ValueTuplePtr &shape) {
   MS_LOG(DEBUG) << "Call start";
   MS_EXCEPTION_IF_NULL(input_tensor);
 

@@ -28,7 +28,7 @@ namespace mindspore {
 namespace kernel {
 namespace pyboost {
 tensor::TensorPtr ReshapeAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &input_tensor,
-    const ValueTuplePtr &shape, OpRunnerInfo *op_runner_info);
+    const ValueTuplePtr &shape);
 } // namespace pyboost
 } // namespace kernel
 } // namespace mindspore
@@ -32,13 +32,8 @@ void SigmoidGradAscendCall(const std::shared_ptr<OpRunner> &op, const device::De
 } // namespace
 
 tensor::TensorPtr SigmoidGradAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &y_tensor,
-    const TensorPtr &dy_tensor, OpRunnerInfo *op_runner_info) {
-  if (op_runner_info != nullptr) {
-    OpRunner::InferOpOutput(op, op_runner_info);
-  } else {
-    OpRunner::InferOpOutput(op, dy_tensor, y_tensor);
-  }
-
+    const TensorPtr &dy_tensor) {
+  OpRunner::InferOpOutput(op, dy_tensor, y_tensor);
   // Create device address for input/output tensors
   PyBoostUtils::PrepareOpInputs(op->device_context(), op->stream_id(), dy_tensor, y_tensor);
   PyBoostUtils::PrepareOpOutputs(op->device_context(), op->stream_id(), op->outputs());

@@ -28,7 +28,7 @@ namespace mindspore {
 namespace kernel {
 namespace pyboost {
 tensor::TensorPtr SigmoidGradAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &y_tensor,
-    const TensorPtr &dy_tensor, OpRunnerInfo *op_runner_info);
+    const TensorPtr &dy_tensor);
 } // namespace pyboost
 } // namespace kernel
 } // namespace mindspore
@@ -35,13 +35,8 @@ void SoftmaxAscendCall(const std::shared_ptr<OpRunner> &op, const device::Device
 } // namespace
 
 tensor::TensorPtr SoftmaxAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &logits_tensor,
-    const ValueTuplePtr &axis, OpRunnerInfo *op_runner_info) {
-  if (op_runner_info != nullptr) {
-    OpRunner::InferOpOutput(op, op_runner_info);
-  } else {
-    OpRunner::InferOpOutput(op, logits_tensor, axis);
-  }
-
+    const ValueTuplePtr &axis) {
+  OpRunner::InferOpOutput(op, logits_tensor, axis);
   // ValueTuple to std::vector
   auto axis_vector = ConvertValueTupleToVector<int64_t>(axis);
   auto dim = axis_vector[0];

@@ -28,7 +28,7 @@ namespace mindspore {
 namespace kernel {
 namespace pyboost {
 tensor::TensorPtr SoftmaxAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &logits_tensor,
-    const ValueTuplePtr &axis, OpRunnerInfo *op_runner_info);
+    const ValueTuplePtr &axis);
 } // namespace pyboost
 } // namespace kernel
 } // namespace mindspore
@@ -36,14 +36,8 @@ void SquareAscendCall(const std::shared_ptr<OpRunner> &op, const device::DeviceC
 }
 } // namespace
 
-tensor::TensorPtr SquareAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &x_tensor,
-    OpRunnerInfo *op_runner_info) {
-  if (op_runner_info != nullptr) {
-    OpRunner::InferOpOutput(op, op_runner_info);
-  } else {
-    OpRunner::InferOpOutput(op, x_tensor);
-  }
-
+tensor::TensorPtr SquareAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &x_tensor) {
+  OpRunner::InferOpOutput(op, x_tensor);
   // No need to convert input
   PyBoostUtils::PrepareOpInputs(op->device_context(), op->stream_id(), x_tensor);
   PyBoostUtils::PrepareOpOutputs(op->device_context(), op->stream_id(), op->outputs());

@@ -27,8 +27,7 @@
 namespace mindspore {
 namespace kernel {
 namespace pyboost {
-tensor::TensorPtr SquareAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &x_tensor,
-    OpRunnerInfo *op_runner_info);
+tensor::TensorPtr SquareAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &x_tensor);
 } // namespace pyboost
 } // namespace kernel
 } // namespace mindspore
@@ -26,13 +26,9 @@
 
 namespace mindspore::kernel::pyboost {
 void TileAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &input_x_tensor,
-    const ValueTuplePtr &dims, OpRunnerInfo *op_runner_info) {
+    const ValueTuplePtr &dims) {
   MS_EXCEPTION_IF_NULL(op);
-  if (op_runner_info != nullptr) {
-    OpRunner::InferOpOutput(op, op_runner_info);
-  } else {
-    OpRunner::InferOpOutput(op, input_x_tensor, dims);
-  }
+  OpRunner::InferOpOutput(op, input_x_tensor, dims);
   std::vector<int64_t> multiples_vector = ConvertValueTupleToVector<int64_t>(dims);
 
   // Expand dims with 1 in head when its length is less than x rank.

@@ -25,6 +25,6 @@
 
 namespace mindspore::kernel::pyboost {
 void TileAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &input_tensor,
-    const ValueTuplePtr &multiples, OpRunnerInfo *op_runner_info);
+    const ValueTuplePtr &multiples);
 } // namespace mindspore::kernel::pyboost
 #endif // MINDSPORE_MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_KERNEL_PYBOOST_CUSTOMIZE_TILE_H_
@@ -34,14 +34,9 @@ tensor::TensorPtr UpsampleNearest1dAscendCall(const std::shared_ptr<OpRunner> &o
 } // namespace
 
 tensor::TensorPtr UpsampleNearest1dAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &input_tensor,
-    const ValueTuplePtr &output_size, const ValueTuplePtr &scale_factors,
-    OpRunnerInfo *op_runner_info) {
-  if (op_runner_info != nullptr) {
-    OpRunner::InferOpOutput(op, op_runner_info);
-  } else {
-    OpRunner::InferOpOutput(op, input_tensor, output_size, scale_factors);
-  }
-
+    const ValueTuplePtr &output_size,
+    const ValueTuplePtr &scale_factors) {
+  OpRunner::InferOpOutput(op, input_tensor, output_size, scale_factors);
   std::vector<int64_t> output_size_vector = ConvertValueTupleToVector<int64_t>(output_size);
 
   PyBoostUtils::PrepareOpInputs(op->device_context(), op->stream_id(), input_tensor);

@@ -28,8 +28,8 @@ namespace mindspore {
 namespace kernel {
 namespace pyboost {
 tensor::TensorPtr UpsampleNearest1dAscendCustomize(const std::shared_ptr<OpRunner> &op, const TensorPtr &input_tensor,
-    const ValueTuplePtr &output_size, const ValueTuplePtr &scale_factors,
-    OpRunnerInfo *op_runner_info);
+    const ValueTuplePtr &output_size,
+    const ValueTuplePtr &scale_factors);
 } // namespace pyboost
 } // namespace kernel
 } // namespace mindspore
@@ -30,7 +30,7 @@ class ${op_name}Ascend : public pyboost::${op_name} {
       : ${op_name}(std::move(primitive), device_context) {}
   ~${op_name}Ascend() = default;
 
-  ${return_type} Call(${call_args_with_type}, OpRunnerInfo * op_runner_info = nullptr) override;
+  ${return_type} Call(${call_args_with_type}) override;
 };
 } // namespace pyboost
 } // namespace kernel

@@ -23,7 +23,7 @@ ${customize_include}
 namespace mindspore {
 namespace kernel {
 namespace pyboost {
-${return_type} ${op_name}Ascend::Call(${call_args_with_type}, OpRunnerInfo * op_runner_info) {
+${return_type} ${op_name}Ascend::Call(${call_args_with_type}) {
   ${call_impl}
 }
 MS_REG_PYBOOST_OP(Ascend, ${op_name});
@@ -1,9 +1,5 @@
 MS_LOG(DEBUG) << op_name() << " call start";
-if (op_runner_info != nullptr) {
-  InferOutput(op_runner_info);
-} else {
-  InferOutput(${call_args});
-}
+InferOutput(${call_args});
 // ValueTuple to std::vector
 ${value_tuple_convert}
 // Convert ValuePtr to c++ scalar

@@ -1,2 +1,2 @@
-${customize_func}(get_op(), ${call_args}, op_runner_info);
+${customize_func}(get_op(), ${call_args});
 return ${return_values};
@@ -1,9 +1,5 @@
 MS_LOG(DEBUG) << op_name() << " call start";
-if (op_runner_info != nullptr) {
-  InferOutput(op_runner_info);
-} else {
-  InferOutput(${call_args});
-}
+InferOutput(${call_args});
 
 ${tensor_list_convert}
 MS_EXCEPTION_IF_NULL(primitive());
@@ -30,7 +30,7 @@ class ${op_name}CPU : public pyboost::${op_name} {
      : ${op_name}(std::move(primitive), device_context) {}
   ~${op_name}CPU() = default;
 
-  ${return_type} Call(${call_args_with_type}, OpRunnerInfo * op_runner_info = nullptr) override;
+  ${return_type} Call(${call_args_with_type}) override;
 };
 } // namespace pyboost
 } // namespace kernel

@@ -21,7 +21,7 @@ ${customize_include}
 namespace mindspore {
 namespace kernel {
 namespace pyboost {
-${return_type} ${op_name}CPU::Call(${call_args_with_type}, OpRunnerInfo * op_runner_info) {
+${return_type} ${op_name}CPU::Call(${call_args_with_type}) {
   ${call_impl}
 }
 MS_REG_PYBOOST_OP(CPU, ${op_name});

@@ -1,9 +1,5 @@
 MS_LOG(DEBUG) << op_name() << " call start";
-if (op_runner_info != nullptr) {
-  InferOutput(op_runner_info);
-} else {
-  InferOutput(${call_args});
-}
+InferOutput(${call_args});
 
 ${tensor_list_convert}
 
@@ -30,7 +30,7 @@ class ${op_name}GPU : public pyboost::${op_name} {
      : ${op_name}(std::move(primitive), device_context) {}
   ~${op_name}GPU() = default;
 
-  ${return_type} Call(${call_args_with_type}, OpRunnerInfo * op_runner_info = nullptr) override;
+  ${return_type} Call(${call_args_with_type}) override;
 };
 } // namespace pyboost
 } // namespace kernel

@@ -22,7 +22,7 @@ ${customize_include}
 namespace mindspore {
 namespace kernel {
 namespace pyboost {
-${return_type} ${op_name}GPU::Call(${call_args_with_type}, OpRunnerInfo * op_runner_info) {
+${return_type} ${op_name}GPU::Call(${call_args_with_type}) {
   ${call_impl}
 }
 MS_REG_PYBOOST_OP(GPU, ${op_name});
@@ -106,7 +106,7 @@ void ChildAtFork() {
     MS_LOG(DEBUG) << "Release GIL lock acquired manually before fork.";
     PyGILState_Release(static_cast<PyGILState_STATE>(ForkUtils::GetInstance().GetGilState()));
   }
-  device::DeviceContextManager::GetInstance().ChildAfterFork();
   // Trigger ChildAfterFork callbacks in child process.
   ForkUtils::GetInstance().ChildAtFork();
 }
@@ -21,11 +21,11 @@
 #include <string>
 #include <vector>
 #include "kernel/pyboost/op_runner.h"
+#include "runtime/pynative/op_runner.h"
 #include "runtime/pynative/op_function/func_object.h"
 #include "backend/graph_compiler/backend.h"
 
 namespace mindspore::runtime {
-using OpRunnerInfo = kernel::pyboost::OpRunnerInfo;
 using Func = std::function<void(OpRunnerInfo *, VectorRef *)>;
 
 class PyBoostOpExecute {
@@ -6,8 +6,10 @@ void ${func_name}(OpRunnerInfo* op_runner_info, VectorRef *op_outputs) {
 
 // Run op
 ${convert_body}
-(void)op->Call(${call_args}, op_runner_info);
+(void)op->Call(${call_args});
+op_runner_info->output_abs = op->output_abs();
 MS_EXCEPTION_IF_NULL(op_outputs);
+MS_EXCEPTION_IF_NULL(op_runner_info->output_abs);
 (void)std::transform(op->outputs().begin(), op->outputs().end(), std::back_inserter(*op_outputs),
                      [] (const auto &item) {return item;});
 }
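A hand-expanded instance of this template for an assumed Add wrapper is sketched below. Only the Call/output_abs handling is taken from this commit; the op construction and the input converters stand in for ${convert_body} and are hypothetical names, not generated code.

// Hypothetical expansion; the point is the new order: Call() no longer receives
// op_runner_info, so the wrapper publishes the inferred abstract afterwards.
void AddGradFuncSketch(OpRunnerInfo *op_runner_info, VectorRef *op_outputs) {
  // Run op
  auto op = CreateAddOpForTarget(op_runner_info->device_target);    // stand-in for ${convert_body}
  auto x_tensor = ConvertValueToTensor(op_runner_info->inputs[0]);  // hypothetical converter
  auto y_tensor = ConvertValueToTensor(op_runner_info->inputs[1]);
  (void)op->Call(x_tensor, y_tensor);
  op_runner_info->output_abs = op->output_abs();  // new in this commit
  MS_EXCEPTION_IF_NULL(op_outputs);
  MS_EXCEPTION_IF_NULL(op_runner_info->output_abs);
  (void)std::transform(op->outputs().begin(), op->outputs().end(), std::back_inserter(*op_outputs),
                       [](const auto &item) { return item; });
}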
@@ -17,7 +17,7 @@
 #include "runtime/pynative/op_function/pyboost_grad_functions.h"
 #include "runtime/pynative/op_executor.h"
 #include "runtime/pynative/op_function/value_converter.h"
-#include "kernel/pyboost/py_boost_utils.h"
+#include "kernel/pyboost/pyboost_utils.h"
 #include "runtime/pynative/op_function/pyboost_grad_functions.h"
 #include "backend/graph_compiler/vmimpl.h"
 #include "include/common/utils/python_adapter.h"
@@ -24,6 +24,15 @@
 #include "runtime/hardware/device_context.h"
 
 namespace mindspore::runtime {
+struct OpRunnerInfo {
+  const PrimitivePtr &prim;
+  const std::string &device_target;
+  const vector<ValuePtr> &inputs;
+  const abstract::AbstractBasePtrList &inputs_abs;
+  const std::vector<InputType> &inputs_mask;
+  abstract::AbstractBasePtr output_abs;
+};
+
 class OpRunner {
  public:
  // Update Tensor or input node DeviceAddress before PyNative async running.
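Taken together with the backend and FuncBuilder hunks at the top of this commit, a runtime-side caller now builds this relocated struct directly and hands it to PyBoostOpExecute. A compact sketch of that flow, assuming the primitive, input values, abstracts and masks are already at hand:

// Sketch assembled from the call sites changed in this commit; variable setup is assumed.
VectorRef op_outputs;
runtime::OpRunnerInfo op_runner_info{prim,           // PrimitivePtr of the op to run
                                     device_target,  // e.g. "Ascend"
                                     input_values,   // std::vector<ValuePtr>
                                     input_abs,      // abstract::AbstractBasePtrList
                                     input_mask,     // std::vector<InputType>
                                     nullptr};       // output_abs, filled in after the call
runtime::PyBoostOpExecute::GetInstance().RunPyBoostCall(&op_runner_info, &op_outputs);
// op_runner_info.output_abs now holds the inferred output abstract (set by the generated wrapper).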