add dvm kernel module

parent dc3f448bc7
commit a51ee7aff0

@@ -0,0 +1,2 @@
mindspore/ccsrc/plugin/device/ascend/kernel/dvm/prebuild/aarch64/libdvm.a filter=lfs diff=lfs merge=lfs -text
mindspore/ccsrc/plugin/device/ascend/kernel/dvm/prebuild/x86_64/libdvm.a filter=lfs diff=lfs merge=lfs -text
@@ -397,7 +397,7 @@ static std::unordered_set<std::string> kAclnnEnableList;
 bool ReadAclnnEnableEnv(const AnfNodePtr &node) {
   static auto enable_aclnn_env = common::GetEnv("MS_ENABLE_ACLNN");
   if (enable_aclnn_env == "1") {
-    return kernel::IsRegisteredAclnnOp(node);
+    return kernel::IsRegisteredAclnnOp(node) ? true : false;
   }

   static auto read_config = !enable_aclnn_env.empty() && enable_aclnn_env != "0";
@@ -417,7 +417,7 @@ bool ReadAclnnEnableEnv(const AnfNodePtr &node) {

   std::string op_name = common::AnfAlgo::GetCNodeName(node);
   if (kAclnnEnableList.count(op_name) != 0) {
-    return kernel::IsRegisteredAclnnOp(node);
+    return kernel::IsRegisteredAclnnOp(node) ? true : false;
   }
 }

@@ -642,6 +642,26 @@ void SetKernelInfoBeforeCreateKernel(const std::vector<CNodePtr> &nodes) {
     }
   }
 }
+
+class AscendGraphKernelInfo : public GraphKernelInfo {
+ public:
+  AscendGraphKernelInfo() = default;
+  virtual ~AscendGraphKernelInfo() = default;
+  void SetKernelInfo(const CNodePtr &kernel_node, KernelType) override {
+    MS_EXCEPTION_IF_NULL(kernel_node);
+    const auto &kernel_graph = AnfAlgo::FetchKernelGraph(kernel_node.get());
+    MS_EXCEPTION_IF_NULL(kernel_graph);
+    auto enable_aclnn = IsEnableAclNN(kernel_graph, kernel_node);
+    auto [select_res, msg, etype] = SelectKernelInfoWithMsg(kernel_node, enable_aclnn);
+    if (select_res) {
+      return;
+    }
+    MS_LOG(INFO) << "node is " << kernel_node->fullname_with_scope() << " should backoff";
+    std::pair<std::string, ExceptionType> failure_info = std::make_pair(msg, etype);
+    HandleKernelSelectFailure(kernel_graph, kernel_node, failure_info);
+  }
+};
+REG_GRAPH_KERNEL_INFO(kAscendDevice, AscendGraphKernelInfo);
 }  // namespace ascend
 }  // namespace device
 }  // namespace mindspore

@@ -0,0 +1,9 @@
approvers:
- gaoxiong1 #
- ckey_dou
- dayschan
- anyrenwei
- zichun_ye

options:
  no_parent_owners: true
@@ -0,0 +1,140 @@
/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef _DVM_H_
#define _DVM_H_

#include <cstdint>
#include <vector>

namespace dvm {
enum DType {
  kInt8 = 0,
  kFloat16,
  kFloat32,
  kInt32,
  kTypeEnd,
};

enum UnaryOpType {
  kSqrt = 0,
  kRsqrt,
  kAbs,
  kLog,
  kExp,
  kReciprocal,
  kIsFinite,
  kLogicalNot,
  kUnaryOpEnd,
};

enum BinaryOpType {
  kEqual = 0,
  kNotEqual,
  kGreater,
  kGreaterEqual,
  kLess,
  kLessEqual,
  kAdd,
  kSub,
  kMul,
  kDiv,
  kPow,
  kMaximum,
  kMinimum,
  kLogicalAnd,
  kLogicalOr,
  kBinaryOpEnd,
};

enum ReduceOpType {
  kSum = 0,
  kReduceOpEnd,
};

enum KernelType {
  kStaticShape = 0,
  kDynShape,
  kStaticParallel,
  kEager,
  kKernelTmplEnd,
};

class NDObject;
class VKernel;

struct ShapeRef {
  ShapeRef() {}
  explicit ShapeRef(const std::vector<int64_t> &other) : data(other.data()), size(other.size()) {}
  ShapeRef &operator=(const std::vector<int64_t> &other) {
    data = other.data();
    size = other.size();
    return *this;
  }
  const int64_t *data;
  size_t size;
};

struct RelocTable {
  NDObject **inputs;
  size_t inputs_size;
  NDObject **outputs;
  size_t outputs_size;
};

class Kernel {
 public:
  Kernel();
  ~Kernel();

  void Reset(KernelType type);
  void Reserve(size_t size);
  // int ParallelNext(); -- TODO: parallel fusion

  NDObject *Load(void *addr, ShapeRef *shape, DType type);
  NDObject *Store(void *addr, NDObject *input);

  NDObject *Unary(int op_type, NDObject *input);
  NDObject *Binary(int op_type, NDObject *lhs, NDObject *rhs);
  NDObject *Binary(int op_type, float val, NDObject *rhs);
  NDObject *Binary(int op_type, NDObject *lhs, float val);
  NDObject *Reduce(int op_type, NDObject *input, ShapeRef *dims, bool keepdims = false);
  NDObject *Select(NDObject *cond, NDObject *lhs, NDObject *rhs);

  NDObject *Cast(NDObject *input, DType type);
  NDObject *Broadcast(NDObject *input, ShapeRef *shape);
  NDObject *Broadcast(float val, ShapeRef *shape, DType type, bool dummy_load);
  NDObject *Reshape(NDObject *input, ShapeRef *shape);
  NDObject *Copy(NDObject *input);

  NDObject *ElemAny(NDObject *input);

  int CodeGen();
  int Launch(void *stream);
  int Launch(const RelocTable &reloc_table, void **inputs, void **outputs, void *stream);
  int Launch(NDObject **op, int size, void *stream);

  ShapeRef *GetShape(NDObject *op) const;
  DType GetDType(NDObject *op) const;

  const char *DisAssemble();
  VKernel *GetImpl() const { return kernel_; }

 private:
  VKernel *kernel_;
};
}  // namespace dvm
#endif  // _DVM_H_
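The header above is the entire public surface of the prebuilt libdvm: record a dataflow graph of NDObjects via Load/compute/Store, code-generate, then launch on a stream. Here is a minimal sketch of that flow, not part of the commit: the buffer pointers and stream are placeholders for what the Ascend runtime would provide, and treating a non-zero return as failure is an assumption borrowed from how DvmKernelMod::Launch checks its own result.

#include <cstdint>
#include <vector>
#include "plugin/device/ascend/kernel/dvm/dvm.h"

// Build, codegen and launch out = a + b for two float16 vectors of length n.
int LaunchFusedAdd(void *dev_a, void *dev_b, void *dev_out, void *stream, int64_t n) {
  std::vector<int64_t> shape_data{n};
  dvm::ShapeRef shape(shape_data);  // ShapeRef only points at shape_data, which must stay alive

  dvm::Kernel kernel;
  kernel.Reset(dvm::KernelType::kStaticShape);
  auto *a = kernel.Load(dev_a, &shape, dvm::DType::kFloat16);
  auto *b = kernel.Load(dev_b, &shape, dvm::DType::kFloat16);
  kernel.Store(dev_out, kernel.Binary(dvm::BinaryOpType::kAdd, a, b));

  if (kernel.CodeGen() != 0) {  // specialize for the recorded static shapes
    return -1;
  }
  return kernel.Launch(stream);  // addresses were already bound at Load/Store time
}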
@@ -0,0 +1,419 @@
/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "plugin/device/ascend/kernel/dvm/dvm_kernel_build.h"
#include <memory>
#include <vector>
#include <string>
#include <utility>
#include <unordered_map>
#include "plugin/device/ascend/kernel/dvm/dvm_kernel_mod.h"
#include "include/common/utils/anfalgo.h"
#include "include/backend/anf_runtime_algorithm.h"

namespace mindspore {
namespace kernel {
namespace {
enum OpType {
  OP_UNARY,
  OP_BINARY,
  OP_RESHAPE,
  OP_BROADCAST,
  OP_CAST,
  OP_NEG,
  OP_SELECT,
  OP_RSQRT,
  OP_ASSIGN,
  OP_ELEMENY,
};

// Maps a MindSpore op name to its dispatch category and the dvm sub-op enum value.
static std::unordered_map<std::string, std::pair<OpType, int>> op_type_map = {
  {"Abs", {OP_UNARY, dvm::UnaryOpType::kAbs}},
  {"Exp", {OP_UNARY, dvm::UnaryOpType::kExp}},
  {"Log", {OP_UNARY, dvm::UnaryOpType::kLog}},
  {"Sqrt", {OP_UNARY, dvm::UnaryOpType::kSqrt}},
  {"Neg", {OP_NEG, 0}},
  {"Cast", {OP_CAST, 0}},
  {"Add", {OP_BINARY, dvm::BinaryOpType::kAdd}},
  {"Sub", {OP_BINARY, dvm::BinaryOpType::kSub}},
  {"Mul", {OP_BINARY, dvm::BinaryOpType::kMul}},
  {"Div", {OP_BINARY, dvm::BinaryOpType::kDiv}},
  {"RealDiv", {OP_BINARY, dvm::BinaryOpType::kDiv}},
  {"Greater", {OP_BINARY, dvm::BinaryOpType::kGreater}},
  {"Maximum", {OP_BINARY, dvm::BinaryOpType::kMaximum}},
  {"Minimum", {OP_BINARY, dvm::BinaryOpType::kMinimum}},
  {"BroadcastTo", {OP_BROADCAST, 0}},
  {"GreaterEqual", {OP_BINARY, dvm::BinaryOpType::kGreaterEqual}},
  {"Less", {OP_BINARY, dvm::BinaryOpType::kLess}},
  {"LessEqual", {OP_BINARY, dvm::BinaryOpType::kLessEqual}},
  {"Equal", {OP_BINARY, dvm::BinaryOpType::kEqual}},
  {"NotEqual", {OP_BINARY, dvm::BinaryOpType::kNotEqual}},
  {"Reciprocal", {OP_UNARY, dvm::UnaryOpType::kReciprocal}},
  {"Reshape", {OP_RESHAPE, 0}},
  {"Select", {OP_SELECT, 0}},
  {"LogicalNot", {OP_UNARY, dvm::UnaryOpType::kLogicalNot}},
  {"LogicalOr", {OP_BINARY, dvm::BinaryOpType::kLogicalOr}},
  {"LogicalAnd", {OP_BINARY, dvm::BinaryOpType::kLogicalAnd}},
  {"Rsqrt", {OP_RSQRT, 0}},
  {"Assign", {OP_ASSIGN, 0}},
  {"ElemAny", {OP_ELEMENY, 0}},
  {"IsFinite", {OP_UNARY, dvm::UnaryOpType::kIsFinite}}};

class OpBuilder {
 public:
  OpBuilder(dvm::Kernel *kernel, const AnfNodePtrList &outputs, std::unordered_map<AnfNodePtr, ShapeRefPtr> *shapes_ref,
            std::vector<ShapeVector> *shapes_ref_source, bool empty_input)
      : kernel_(kernel), shapes_ref_(shapes_ref), shapes_ref_source_(shapes_ref_source), empty_input_(empty_input) {
    for (const auto &node : outputs) {
      outputs_[node] = nullptr;
    }
  }
  ~OpBuilder() = default;

  std::pair<float, TypeId> GetScalarFromTensor(const AnfNodePtr &node) {
    auto value_node = node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    auto input_tensor = value_node->value()->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(input_tensor);
    auto type_id = input_tensor->data_type();
    float scalar;
    if (type_id == TypeId::kNumberTypeFloat32) {
      scalar = static_cast<float *>(input_tensor->data_c())[0];
    } else if (type_id == TypeId::kNumberTypeFloat16) {
      scalar = static_cast<float>(static_cast<float16 *>(input_tensor->data_c())[0]);
    } else {
      MS_LOG(EXCEPTION) << "Data type of scalar value input only supports float, but got: " << TypeIdToString(type_id)
                        << " node: " << node->fullname_with_scope();
    }
    return {scalar, type_id};
  }

  void Emit(const AnfNodePtr &anf_node) {
    auto node = anf_node->cast<CNodePtr>();
    auto prim = GetCNodePrimitive(node);
    MS_EXCEPTION_IF_NULL(prim);
    auto prim_name = prim->name();
    auto op_type = op_type_map.find(prim_name);
    if (op_type == op_type_map.end()) {
      MS_LOG(EXCEPTION) << "Unsupported fused op: " << prim_name;
    }
    switch (op_type->second.first) {
      case OP_CAST: {
        auto dst_dtype = AnfAlgo::GetOutputDeviceDataType(node, 0);
        auto op = EmitCast(GetInput(node->input(1)), dst_dtype);
        EmitOp(anf_node, op);
        break;
      }
      case OP_SELECT: {
        auto op = kernel_->Select(GetInput(node->input(1)), GetInput(node->input(2)), GetInput(node->input(3)));
        EmitOp(anf_node, op);
        break;
      }
      case OP_UNARY: {
        auto op = kernel_->Unary(op_type->second.second, GetInput(node->input(1)));
        EmitOp(anf_node, op);
        break;
      }
      case OP_RSQRT: {
        auto sqrt_op = kernel_->Unary(dvm::UnaryOpType::kSqrt, GetInput(node->input(1)));
        auto op = kernel_->Unary(dvm::UnaryOpType::kReciprocal, sqrt_op);
        EmitOp(anf_node, op);
        break;
      }
      case OP_RESHAPE: {
        auto shape_ref = CacheShape(node, 2);
        auto op = kernel_->Reshape(GetInput(node->input(1)), shape_ref);
        EmitOp(anf_node, op);
        break;
      }
      case OP_BINARY: {
        auto input1 = node->input(1);
        auto input2 = node->input(2);
        auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input1->Shape())[kShape];
        auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input2->Shape())[kShape];
        auto input1_sz = std::accumulate(input1_shape.begin(), input1_shape.end(), 1, std::multiplies{});
        auto input2_sz = std::accumulate(input2_shape.begin(), input2_shape.end(), 1, std::multiplies{});
        bool input1_scalar = (input1_sz == 1 && input1->cast<ValueNodePtr>() != nullptr);
        bool input2_scalar = (input2_sz == 1 && input2->cast<ValueNodePtr>() != nullptr);
        dvm::NDObject *op = nullptr;
        // Fold a scalar constant operand into the immediate-operand overloads of Binary.
        if (input1_scalar) {
          float scalar = GetScalarFromTensor(input1).first;
          op = kernel_->Binary(op_type->second.second, scalar, GetInput(input2));
        } else if (input2_scalar) {
          float scalar = GetScalarFromTensor(input2).first;
          op = kernel_->Binary(op_type->second.second, GetInput(input1), scalar);
        } else {
          op = kernel_->Binary(op_type->second.second, GetInput(input1), GetInput(input2));
        }
        EmitOp(anf_node, op);
        break;
      }
      case OP_BROADCAST: {
        auto input = node->input(1);
        auto shape_ref = CacheShape(node, 2);
        auto op = input->isa<ValueNode>() ? EmitScalarBroadcast(input, shape_ref)
                                          : kernel_->Broadcast(GetInput(input), shape_ref);
        EmitOp(anf_node, op);
        break;
      }
      case OP_NEG: {
        auto input = node->input(1);
        auto op = kernel_->Binary(dvm::BinaryOpType::kMul, GetInput(input), -1.0);
        EmitOp(anf_node, op);
        break;
      }
      case OP_ASSIGN: {
        auto out_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
        auto input2 = EmitCast(GetInput(node->input(2)), out_type);
        // store the second input of assign to the corresponding parameter of subgraph
        auto input1 = node->input(1);
        ops_map_[input1] = kernel_->Store(nullptr, input2);
        outputs_[input1] = ops_map_[input1];
        if (outputs_.find(node) != outputs_.end()) {
          // store the second input of assign to the output of subgraph
          ops_map_[anf_node] = kernel_->Store(nullptr, input2);
          outputs_[node] = ops_map_[anf_node];
        }
        break;
      }
      case OP_ELEMENY: {
        auto dst_dtype = AnfAlgo::GetOutputDeviceDataType(node, 0);
        auto op = kernel_->ElemAny(EmitCast(GetInput(node->input(1)), dst_dtype));
        EmitOp(anf_node, op);
        break;
      }
      default:
        MS_LOG(EXCEPTION) << op_type->second << " is an unsupported op type.";
        break;
    }
  }

  dvm::NDObject *GetLoad(const AnfNodePtr &node) {
    auto it = inputs_.find(node);
    return (it == inputs_.end() ? nullptr : it->second);
  }

  dvm::NDObject *GetStore(const AnfNodePtr &node) {
    auto it = outputs_.find(node);
    return (it == outputs_.end() ? nullptr : it->second);
  }

 private:
  dvm::NDObject *GetInput(const AnfNodePtr &node) {
    auto it = ops_map_.find(node);
    if (it == ops_map_.end()) {
      dvm::NDObject *op = nullptr;
      if (node->isa<ValueNode>()) {
        shapes_ref_source_->push_back({1});
        (*shapes_ref_)[node] = std::make_shared<dvm::ShapeRef>(shapes_ref_source_->back());
        op = EmitScalarBroadcast(node, (*shapes_ref_)[node].get());
      } else if (node->isa<Parameter>()) {
        // hit subgraph input
        auto type_id = AnfAlgo::GetOutputDeviceDataType(node, 0);
        auto iter = ms_type_map.find(type_id);
        if (iter == ms_type_map.end()) {
          MS_LOG(EXCEPTION) << node->ToString() << "'s type " << TypeIdToString(type_id)
                            << " is an unsupported data type.";
        }
        auto shape = AnfAlgo::GetOutputDeviceShape(node, 0);
        shapes_ref_source_->push_back(shape);
        (*shapes_ref_)[node] = std::make_shared<dvm::ShapeRef>(shapes_ref_source_->back());
        op = kernel_->Load(nullptr, (*shapes_ref_)[node].get(), iter->second);
        inputs_[node] = op;
      } else {
        MS_LOG(EXCEPTION) << node->DebugString() << " is an unsupported node type.";
      }
      ops_map_[node] = op;
      return op;
    }
    return it->second;
  }

  dvm::ShapeRef *CacheShape(const CNodePtr &node, size_t input_idx) {
    auto shape = AnfAlgo::GetOutputDeviceShape(node, 0);
    if (IsDynamic(shape)) {
      // Although param is a subgraph input, there is no need to emit a Load op
      // for it, because its value is only needed in infer shape
      auto param = node->input(input_idx);
      MS_EXCEPTION_IF_NULL(param);
      if (!param->isa<Parameter>()) {
        MS_LOG(EXCEPTION) << "For " << node->fullname_with_scope() << ", input[" << (input_idx - 1)
                          << "] must be a Parameter, but got: " << param->ToString();
      }
      (*shapes_ref_)[param] = std::make_shared<dvm::ShapeRef>();
      return (*shapes_ref_)[param].get();
    }
    shapes_ref_source_->push_back(shape);
    (*shapes_ref_)[node] = std::make_shared<dvm::ShapeRef>(shapes_ref_source_->back());
    return (*shapes_ref_)[node].get();
  }

  dvm::NDObject *EmitScalarBroadcast(const AnfNodePtr &node, dvm::ShapeRef *shape) {
    auto [scalar, type_id] = GetScalarFromTensor(node);
    auto v_type_id = ms_type_map[type_id];
    auto op = kernel_->Broadcast(scalar, shape, v_type_id, empty_input_);
    if (empty_input_) {
      empty_input_ = false;  // now we have a fake input
    }
    return op;
  }

  dvm::NDObject *EmitCast(dvm::NDObject *obj, TypeId dst_type) {
    auto it = ms_type_map.find(dst_type);
    if (it == ms_type_map.end()) {
      MS_LOG(EXCEPTION) << "Unsupported data type '" << TypeIdToString(dst_type) << "' for Cast";
    }
    if (kernel_->GetDType(obj) == it->second) {
      return obj;
    }
    return kernel_->Cast(obj, it->second);
  }

  void EmitOp(const AnfNodePtr &node, dvm::NDObject *obj) {
    ops_map_[node] = obj;
    if (outputs_.find(node) != outputs_.end()) {
      // hit subgraph output
      auto out_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
      obj = EmitCast(obj, out_type);
      outputs_[node] = kernel_->Store(nullptr, obj);
    }
  }

  dvm::Kernel *kernel_;
  std::unordered_map<AnfNodePtr, dvm::NDObject *> inputs_;
  std::unordered_map<AnfNodePtr, dvm::NDObject *> outputs_;
  std::unordered_map<AnfNodePtr, dvm::NDObject *> ops_map_;
  std::unordered_map<AnfNodePtr, ShapeRefPtr> *shapes_ref_;
  std::vector<ShapeVector> *shapes_ref_source_;
  static std::unordered_map<dvm::DType, TypeId> v_type_map;
  static std::unordered_map<TypeId, dvm::DType> ms_type_map;
  bool empty_input_{false};
};

std::unordered_map<dvm::DType, TypeId> OpBuilder::v_type_map = {{dvm::DType::kFloat32, TypeId::kNumberTypeFloat32},
                                                                {dvm::DType::kFloat16, TypeId::kNumberTypeFloat16},
                                                                {dvm::DType::kInt8, TypeId::kNumberTypeBool},
                                                                {dvm::DType::kInt32, TypeId::kNumberTypeInt32}};

std::unordered_map<TypeId, dvm::DType> OpBuilder::ms_type_map = {{TypeId::kNumberTypeFloat32, dvm::DType::kFloat32},
                                                                 {TypeId::kNumberTypeFloat16, dvm::DType::kFloat16},
                                                                 {TypeId::kNumberTypeBool, dvm::DType::kInt8},
                                                                 {TypeId::kNumberTypeInt32, dvm::DType::kInt32}};

class DvmKernelBuilder {
 public:
  DvmKernelBuilder() = default;
  ~DvmKernelBuilder() = default;

  void Construct(const FuncGraphPtr &graph) {
    MS_EXCEPTION_IF_NULL(graph);
    AnfNodePtr end_node = graph->get_return();
    MS_EXCEPTION_IF_NULL(end_node);
    auto ret_node = end_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(ret_node);
    auto out_node = ret_node->input(1)->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(out_node);
    std::vector<AnfNodePtr> outputs;
    if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) {
      auto tuple = out_node->cast<CNodePtr>();
      for (size_t i = 1; i < tuple->size(); ++i) {
        outputs.emplace_back(tuple->input(i));
      }
      end_node = out_node;
    } else {
      outputs.emplace_back(out_node);
    }
    const auto &params = graph->parameters();
    std::unordered_map<AnfNodePtr, ShapeRefPtr> shapes_ref;
    OpBuilder builder(kernel_mod_->Kernel(), outputs, &shapes_ref, kernel_mod_->ShapesSource(), params.empty());
    auto nodes = TopoSort(ret_node);
    for (const auto &node : nodes) {
      if (node == end_node) break;
      if (node->isa<CNode>()) {
        builder.Emit(node);
      }
    }
    for (const auto &iter : shapes_ref) {
      kernel_mod_->CacheShapeRef(iter.second);
    }
    // cache kernel's inputs and outputs from subgraph's inputs and outputs
    for (size_t i = 0; i < params.size(); ++i) {
      auto shape_iter = shapes_ref.find(params[i]);
      auto ref = shape_iter == shapes_ref.end() ? nullptr : shape_iter->second.get();
      kernel_mod_->UpdateInputShapeRef(i, ref);
      if (auto load = builder.GetLoad(params[i]); load != nullptr) {
        kernel_mod_->CacheLoad(load, i);
      }
      if (auto store = builder.GetStore(params[i]); store != nullptr) {
        kernel_mod_->CacheStore(store, i);
      }
    }
    for (size_t i = 0; i < outputs.size(); ++i) {
      if (auto store = builder.GetStore(outputs[i]); store != nullptr) {
        kernel_mod_->CacheStore(store, i + params.size());
      }
    }
    kernel_mod_->UpdateIO();
  }

  KernelModPtr Create(const AnfNodePtr &anf_node) {
    MS_EXCEPTION_IF_NULL(anf_node);
    auto cnode = anf_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto scope = cnode->fullname_with_scope();
    MS_LOG(INFO) << "Start creating kernel module for node: " << scope;
    // Create kernel mod
    auto is_dynamic = common::AnfAlgo::IsDynamicShape(anf_node);
    kernel_mod_ = std::make_shared<DvmKernelMod>(is_dynamic);
    auto inputs_type = AnfAlgo::GetAllInputDeviceTypes(cnode);
    auto outputs_type = AnfAlgo::GetAllOutputDeviceTypes(cnode);
    kernel_mod_->Initialize(inputs_type, outputs_type);
    // FuncGraph --> Dvm Kernel
    auto func_graph = GetCNodeFuncGraph(cnode);
    Construct(func_graph);
    if (!is_dynamic) {
      // Static shape needs codegen here
      std::vector<ShapeVector> inputs_shape(inputs_type.size());
      for (size_t i = 0; i < inputs_type.size(); ++i) {
        inputs_shape[i] = AnfAlgo::GetInputDeviceShape(cnode, i);
      }
      std::vector<ShapeVector> outputs_shape(outputs_type.size());
      for (size_t i = 0; i < outputs_type.size(); ++i) {
        outputs_shape[i] = AnfAlgo::GetOutputDeviceShape(cnode, i);
      }
      kernel_mod_->CodeGen(inputs_shape, outputs_shape);
    } else {
      // Dynamic shape needs a prim to hold the infer shape function
      auto prim = std::make_shared<Primitive>(scope);
      prim->set_attr("infer_shape_functor", std::make_shared<DvmInfer>("dvm_infer_functor", kernel_mod_.get()));
      if (!std::static_pointer_cast<KernelMod>(kernel_mod_)->Init(prim, {}, {})) {
        MS_LOG(EXCEPTION) << "Initialize kernel module failed for node: " << scope;
      }
    }
    MS_LOG(INFO) << "End creating kernel module for node: " << scope;
    return kernel_mod_;
  }

 private:
  DvmKernelModPtr kernel_mod_;
};
}  // namespace

KernelModPtr DvmOpBuild(const AnfNodePtr &anf_node) {
  DvmKernelBuilder kernel_builder;
  return kernel_builder.Create(anf_node);
}
}  // namespace kernel
}  // namespace mindspore
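For orientation, here is an illustration (not part of the commit; names and shapes assumed) of what the builder emits: for a fused subgraph out = Mul(Sqrt(p0), p1), the TopoSort walk in DvmKernelBuilder::Construct leads OpBuilder to issue roughly the calls below. Loads and Stores are created with nullptr addresses; the real device addresses are bound at each launch through the RelocTable that DvmKernelMod::UpdateIO assembles.

#include "plugin/device/ascend/kernel/dvm/dvm.h"

// Hypothetical emission for out = Mul(Sqrt(p0), p1), float16, one shared shape.
void EmitExample(dvm::Kernel *kernel, dvm::ShapeRef *shape) {
  auto *p0 = kernel->Load(nullptr, shape, dvm::DType::kFloat16);  // Parameter p0
  auto *p1 = kernel->Load(nullptr, shape, dvm::DType::kFloat16);  // Parameter p1
  auto *v0 = kernel->Unary(dvm::UnaryOpType::kSqrt, p0);          // Sqrt
  auto *v1 = kernel->Binary(dvm::BinaryOpType::kMul, v0, p1);     // Mul
  kernel->Store(nullptr, v1);                                     // subgraph output
}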
@@ -0,0 +1,25 @@
/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_KERNEL_DVM_KERNEL_BUILD_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_KERNEL_DVM_KERNEL_BUILD_H_
#include "kernel/kernel.h"

namespace mindspore {
namespace kernel {
KernelModPtr DvmOpBuild(const AnfNodePtr &anf_node);
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_KERNEL_DVM_KERNEL_BUILD_H_
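A sketch of the intended call site (hypothetical; the graph-kernel pass that dispatches to this entry point is outside this diff):

#include "plugin/device/ascend/kernel/dvm/dvm_kernel_build.h"

// fused_node is assumed to be a graph-kernel CNode whose attached FuncGraph
// contains only ops listed in op_type_map.
mindspore::kernel::KernelModPtr BuildFusedKernel(const mindspore::AnfNodePtr &fused_node) {
  // Static-shape graphs come back already code-generated; dynamic-shape graphs
  // carry a DvmInfer functor and re-codegen inside InferShape at runtime.
  return mindspore::kernel::DvmOpBuild(fused_node);
}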
@@ -0,0 +1,143 @@
/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/ascend/kernel/dvm/dvm_kernel_mod.h"
#include <algorithm>

namespace mindspore {
namespace kernel {
BaseShapePtr DvmInfer::InferShape(const AbstractBasePtrList &args) { return kernel_->InferShape(args); }

DvmKernelMod::DvmKernelMod(bool is_dynamic) {
  auto kernel_type = is_dynamic ? dvm::KernelType::kDynShape : dvm::KernelType::kStaticShape;
  kernel_.Reset(kernel_type);
}

bool DvmKernelMod::Launch(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &,
                          const std::vector<KernelTensor *> &outputs, void *stream_ptr) {
  for (size_t i = 0; i < inputs_addr_.size(); ++i) {
    inputs_addr_[i] = inputs[inputs_idx_[i]]->device_ptr();
  }
  for (size_t i = 0; i < outputs_addr_.size(); ++i) {
    // A cached index below inputs.size() aliases a kernel input (Assign write-back);
    // anything at or above it is a real kernel output.
    auto idx = outputs_idx_[i];
    outputs_addr_[i] = idx < inputs.size() ? inputs[idx]->device_ptr() : outputs[idx - inputs.size()]->device_ptr();
  }
  auto ret = kernel_.Launch(reloc_table_, inputs_addr_.data(), outputs_addr_.data(), stream_ptr);
  return ret == 0;
}

void DvmKernelMod::Initialize(const std::vector<TypeId> &inputs_type, const std::vector<TypeId> &outputs_type) {
  inputs_type_byte_.clear();
  inputs_type_byte_.reserve(inputs_type.size());
  (void)std::transform(inputs_type.begin(), inputs_type.end(), std::back_inserter(inputs_type_byte_),
                       [](const TypeId &type_id) { return GetTypeByte(TypeIdToType(type_id)); });
  outputs_type_byte_.clear();
  outputs_type_byte_.reserve(outputs_type.size());
  (void)std::transform(outputs_type.begin(), outputs_type.end(), std::back_inserter(outputs_type_byte_),
                       [](const TypeId &type_id) { return GetTypeByte(TypeIdToType(type_id)); });
  input_size_list_.resize(inputs_type.size(), 1);
  output_size_list_.resize(outputs_type.size(), 1);
  inputs_shape_.resize(inputs_type.size());
  outputs_shape_.resize(outputs_type.size());
  shapes_ref_source_.reserve(inputs_type.size());
  inputs_shape_ref_.resize(inputs_type.size());
  inputs_.reserve(inputs_type.size());
  outputs_.reserve(outputs_type.size());
  inputs_idx_.reserve(inputs_type.size());
  outputs_idx_.reserve(outputs_type.size());
}

void DvmKernelMod::CodeGen(const std::vector<ShapeVector> &inputs_shape,
                           const std::vector<ShapeVector> &outputs_shape) {
  for (size_t i = 0; i < inputs_shape.size(); ++i) {
    input_size_list_[i] = inputs_type_byte_[i];
    for (auto sh : inputs_shape[i]) {
      input_size_list_[i] *= LongToSize(sh);
    }
  }
  for (size_t i = 0; i < outputs_shape.size(); ++i) {
    output_size_list_[i] = outputs_type_byte_[i];
    for (auto sh : outputs_shape[i]) {
      output_size_list_[i] *= LongToSize(sh);
    }
  }
  kernel_.CodeGen();
}

BaseShapePtr DvmKernelMod::InferShape(const AbstractBasePtrList &inputs_abs) {
  BaseShapePtr result{nullptr};
  // update input shape
  for (size_t i = 0; i < inputs_abs.size(); ++i) {
    // TODO: if an input value is needed in infer shape, fetch the input value here
    inputs_shape_[i] = inputs_abs[i]->GetShape()->GetShapeVector();
    if (inputs_shape_ref_[i] != nullptr) {
      *inputs_shape_ref_[i] = inputs_shape_[i];
    }
    input_size_list_[i] = inputs_type_byte_[i];
    for (auto sh : inputs_shape_[i]) {
      input_size_list_[i] *= LongToSize(sh);
    }
  }
  // re-codegen with the new input shapes
  kernel_.CodeGen();
  // update output shape
  for (size_t i = 0; i < outputs_.size(); ++i) {
    auto idx = outputs_idx_[i];
    if (idx >= inputs_shape_.size()) {
      idx -= inputs_shape_.size();
      auto shape_ref = kernel_.GetShape(outputs_[i]);
      outputs_shape_[idx] = ShapeVector(shape_ref->data, shape_ref->data + shape_ref->size);
      output_size_list_[idx] = outputs_type_byte_[idx];
      for (auto sh : outputs_shape_[idx]) {
        output_size_list_[idx] *= LongToSize(sh);
      }
    }
  }
  // update output abstract
  if (outputs_shape_.size() > 1) {
    abstract::BaseShapePtrList out_shapes(outputs_shape_.size());
    for (size_t i = 0; i < outputs_shape_.size(); ++i) {
      out_shapes[i] = std::make_shared<abstract::TensorShape>(outputs_shape_[i]);
    }
    result = std::make_shared<abstract::TupleShape>(out_shapes);
  } else {
    result = std::make_shared<abstract::TensorShape>(outputs_shape_.front());
  }
  return result;
}

void DvmKernelMod::CacheLoad(dvm::NDObject *obj, size_t idx) {
  inputs_.push_back(obj);
  inputs_idx_.push_back(idx);
}

void DvmKernelMod::CacheStore(dvm::NDObject *obj, size_t idx) {
  outputs_.push_back(obj);
  outputs_idx_.push_back(idx);
}

void DvmKernelMod::UpdateIO() {
  inputs_addr_.resize(inputs_.size());
  outputs_addr_.resize(outputs_.size());
  reloc_table_.inputs = inputs_.data();
  reloc_table_.outputs = outputs_.data();
  reloc_table_.inputs_size = inputs_.size();
  reloc_table_.outputs_size = outputs_.size();
}

void DvmKernelMod::UpdateInputShapeRef(size_t input_idx, dvm::ShapeRef *ref) { inputs_shape_ref_[input_idx] = ref; }
}  // namespace kernel
}  // namespace mindspore
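To make the dynamic-shape contract concrete, here is a reduced sketch (illustration only; names assumed) of the round trip DvmKernelMod::InferShape performs on every call: write the new input shapes through the cached ShapeRef pointers, re-run CodeGen, then read the resulting output shapes back from the kernel. In the real code the shape storage lives in the inputs_shape_ members precisely so that it outlives the launch.

#include <cstdint>
#include <vector>
#include "plugin/device/ascend/kernel/dvm/dvm.h"

// One infer-shape round trip, reduced to the dvm calls involved.
std::vector<int64_t> InferRoundTrip(dvm::Kernel *kernel, dvm::ShapeRef *input_ref, dvm::NDObject *store_op,
                                    const std::vector<int64_t> &new_shape) {
  *input_ref = new_shape;  // ShapeRef::operator= only stores a pointer into new_shape
  kernel->CodeGen();       // re-specialize the kernel for the new input shapes
  dvm::ShapeRef *out = kernel->GetShape(store_op);
  return {out->data, out->data + out->size};  // becomes the new output abstract
}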
@@ -0,0 +1,97 @@
/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_KERNEL_DVM_KERNEL_MOD_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_KERNEL_DVM_KERNEL_MOD_H_

#include <memory>
#include <vector>
#include <string>
#include "ir/functor.h"
#include "kernel/kernel.h"
#include "plugin/device/ascend/kernel/dvm/dvm.h"

namespace mindspore {
namespace kernel {
using ShapeRefPtr = std::shared_ptr<dvm::ShapeRef>;
class DvmKernelMod;
class DvmInfer : public InferShapeFunctor {
 public:
  DvmInfer(const std::string &name, DvmKernelMod *kernel) : InferShapeFunctor(name) { kernel_ = kernel; }
  ~DvmInfer() override = default;
  MS_DECLARE_PARENT(DvmInfer, InferShapeFunctor)
  BaseShapePtr InferShape(const AbstractBasePtrList &args) override;

 private:
  DvmKernelMod *kernel_;
};

class DvmKernelMod : public KernelMod {
 public:
  explicit DvmKernelMod(bool is_dynamic);
  ~DvmKernelMod() = default;

  std::vector<KernelAttr> GetOpSupport() override { MS_LOG(EXCEPTION) << "This interface is not supported in VKernel."; }

  bool Init(const std::vector<KernelTensor *> &, const std::vector<KernelTensor *> &) override { return true; }

  int Resize(const std::vector<KernelTensor *> &, const std::vector<KernelTensor *> &) override { return 0; }

  bool Launch(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &workspace,
              const std::vector<KernelTensor *> &outputs, void *stream_ptr) override;

  void Initialize(const std::vector<TypeId> &inputs_type, const std::vector<TypeId> &outputs_type);

  // used in static shape
  void CodeGen(const std::vector<ShapeVector> &inputs_shape, const std::vector<ShapeVector> &outputs_shape);

  // used in dynamic shape
  BaseShapePtr InferShape(const AbstractBasePtrList &inputs_abs);

  dvm::Kernel *Kernel() { return &kernel_; }

  std::vector<ShapeVector> *ShapesSource() { return &shapes_ref_source_; }

  void CacheShapeRef(const ShapeRefPtr &shape_ref) { shapes_ref_.push_back(shape_ref); }

  void CacheLoad(dvm::NDObject *obj, size_t idx);

  void CacheStore(dvm::NDObject *obj, size_t idx);

  void UpdateIO();

  void UpdateInputShapeRef(size_t input_idx, dvm::ShapeRef *ref);

 private:
  std::vector<ShapeVector> inputs_shape_;
  std::vector<ShapeVector> outputs_shape_;
  std::vector<ShapeVector> shapes_ref_source_;     // keeps alive the shapes pointed to by each ShapeRef
  std::vector<ShapeRefPtr> shapes_ref_;            // manages the dynamically allocated ShapeRef objects
  std::vector<dvm::ShapeRef *> inputs_shape_ref_;  // points at the latest input shapes, used in infer shape
  std::vector<dvm::NDObject *> inputs_;            // cached Load ops
  std::vector<dvm::NDObject *> outputs_;           // cached Store ops
  std::vector<void *> inputs_addr_;
  std::vector<void *> outputs_addr_;
  dvm::RelocTable reloc_table_;
  std::vector<size_t> inputs_idx_;
  std::vector<size_t> outputs_idx_;
  std::vector<size_t> inputs_type_byte_;
  std::vector<size_t> outputs_type_byte_;
  dvm::Kernel kernel_;
};
using DvmKernelModPtr = std::shared_ptr<DvmKernelMod>;
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_KERNEL_DVM_KERNEL_MOD_H_
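The inputs_idx_/outputs_idx_ bookkeeping is easiest to read with concrete numbers (hypothetical example): with three parameters where parameter 1 is the target of an Assign, CacheStore records index 1 for the write-back Store and params.size() + k for the k-th real output. Launch then resolves each cached index as sketched here:

#include <cstddef>
#include <vector>

// Mirrors the address resolution in DvmKernelMod::Launch: indices below
// inputs.size() alias kernel inputs (Assign write-back), the rest are outputs.
void *ResolveStoreAddr(size_t idx, const std::vector<void *> &inputs, const std::vector<void *> &outputs) {
  return idx < inputs.size() ? inputs[idx] : outputs[idx - inputs.size()];
}
// With 3 inputs: idx 1 -> inputs[1] (write-back into parameter 1),
//                idx 3 -> outputs[0] (first real output).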
@@ -0,0 +1,3 @@
[lib information]
git branch: master
commit id: 9928302bfbd57493e6d8b6d527c69cf5bb67caa8

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:65b582850fd2b541b9d7c2a8edd701c0e3e4b9c75bdf1f8a335e75f9b52f12a0
size 396286

@@ -0,0 +1,3 @@
[lib information]
git branch: master
commit id: 9928302bfbd57493e6d8b6d527c69cf5bb67caa8

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:00b0fe60b893ecd4597a382ef7d2baa5a12da37160080f17f97be99a92b2e17e
size 382102