Revert 'Pull Request !47520 : modify cpu_kernel_mod launch bug'

Author: yanghaoran, 2023-01-11 03:53:59 +00:00 (committed by Gitee)
Parent: 2c788e1557
Commit: 6fdfcf57ad
6 changed files with 45 additions and 28 deletions


@@ -1778,6 +1778,17 @@ KernelAttr GetKernelAttrFromTensors(const std::vector<KernelTensorPtr> &inputs,
return kernel_attr;
}
KernelAttr GetKernelAttrFromTypes(const std::vector<TypeId> &inputs, const std::vector<TypeId> &outputs) {
KernelAttr kernel_attr;
for (auto type_id : inputs) {
(void)kernel_attr.AddInputAttr(type_id);
}
for (auto type_id : outputs) {
(void)kernel_attr.AddOutputAttr(type_id);
}
return kernel_attr;
}
void SetCpuRefMapToKernelInfo(const CNodePtr &apply_kernel, const std::vector<KernelAttr> &apply_kernel_attrs) {
auto kernel_attrs = apply_kernel_attrs;
if (kernel_attrs.empty()) {
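
For context, a minimal usage sketch of the restored helper (not part of the diff; the two-input/one-output kernel and the kNumberTypeFloat32 constant are illustrative assumptions):

// Illustrative only: build a KernelAttr for a hypothetical kernel with two
// float32 inputs and one float32 output via the restored helper.
std::vector<TypeId> in_types = {kNumberTypeFloat32, kNumberTypeFloat32};
std::vector<TypeId> out_types = {kNumberTypeFloat32};
KernelAttr attr = GetKernelAttrFromTypes(in_types, out_types);
// Equivalent to calling attr.AddInputAttr() twice and attr.AddOutputAttr() once.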


@@ -412,6 +412,7 @@ BACKEND_EXPORT KernelArgs AbstractArgsFromCNode(const CNodePtr &cnode, bool is_w
BACKEND_EXPORT KernelAttr GetKernelAttrFromTensors(const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs);
BACKEND_EXPORT KernelAttr GetKernelAttrFromTypes(const std::vector<TypeId> &inputs, const std::vector<TypeId> &outputs);
void SetCpuRefMapToKernelInfo(const CNodePtr &apply_kernel, const std::vector<KernelAttr> &apply_kernel_attrs);
Format GetFormatFromStrToEnum(const std::string &format_str);
BACKEND_EXPORT std::string GetFormatFromEnumToStr(Format format);


@@ -74,16 +74,10 @@ TypeId KernelTensor::GetDtype() const {
} else if (meta_type_ == kObjectTypeTuple) {
// Tuple
const TupleInfo &info = std::get<TupleInfo>(meta_);
if (info.base_->dynamic_len()) {
return info.base_->dynamic_len_element_abs()->BuildType()->type_id();
}
return info.base_->elements()[0]->BuildType()->type_id();
} else if (meta_type_ == kObjectTypeList) {
// List
const ListInfo &info = std::get<ListInfo>(meta_);
if (info.base_->dynamic_len()) {
return info.base_->dynamic_len_element_abs()->BuildType()->type_id();
}
return info.base_->elements()[0]->BuildType()->type_id();
} else {
// Tensor
@@ -105,6 +99,27 @@ TypeId KernelTensor::GetDtype() const {
return kTypeUnknown;
}
TypeId KernelTensor::GetScalarDtype() const {
if (meta_type_ != kObjectTypeNumber) {
MS_LOG(EXCEPTION) << "meta_type must be scalar, but got " << meta_type_;
}
const ScalarInfo &info = std::get<ScalarInfo>(meta_);
auto info_type = info.base_->BuildType();
MS_EXCEPTION_IF_NULL(info_type);
return info_type->type_id();
}
TypeId KernelTensor::GetTupleElementDtype() const {
if (meta_type_ != kObjectTypeTuple) {
MS_LOG(EXCEPTION) << "meta_type must be tuple , but got " << meta_type_;
}
const TupleInfo &info = std::get<TupleInfo>(meta_);
if (info.base_->dynamic_len()) {
return info.base_->dynamic_len_element_abs()->BuildType()->type_id();
}
return info.base_->elements()[0]->BuildType()->type_id();
}
ShapeVector KernelTensor::GetShapeVector() const {
if (meta_type_ == kObjectTypeTensorType) {
// Tensor
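
A hedged sketch of how a caller might choose among the dtype accessors added above (illustrative; the helper name ElementDtypeOf and the boolean flags are assumptions, not code from this commit):

// Illustrative only: pick the element dtype of a kernel tensor depending on
// whether it carries a scalar, a tuple, or a plain tensor.
TypeId ElementDtypeOf(const KernelTensorPtr &t, bool is_scalar, bool is_tuple) {
  if (is_scalar) {
    return t->GetScalarDtype();        // requires meta type kObjectTypeNumber
  }
  if (is_tuple) {
    return t->GetTupleElementDtype();  // requires meta type kObjectTypeTuple
  }
  return t->GetDtype();                // tensor and remaining cases
}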


@@ -288,6 +288,8 @@ class BACKEND_EXPORT KernelTensor {
// deprecated field for dynamic shape
const ShapeVector &GetDeviceShapeAdaptively() const;
void SetDeviceShapeAdaptively(const ShapeVector &device_shape_adaptively);
TypeId GetScalarDtype() const;
TypeId GetTupleElementDtype() const;
private:
TypeId meta_type_{kObjectTypeTensorType};
@@ -395,8 +397,6 @@ class BACKEND_EXPORT KernelMod {
std::vector<size_t> output_size_list_;
std::vector<std::vector<int64_t>> output_shapes_;
std::vector<size_t> workspace_size_list_;
std::vector<KernelTensorPtr> inputs_;
std::vector<KernelTensorPtr> outputs_;
bool is_need_retrieve_output_shape_ = false;
uint32_t device_id_ = 0;
@@ -404,6 +404,9 @@
std::vector<AddressPtr> inputs_addr_;
std::vector<AddressPtr> workspaces_addr_;
std::vector<AddressPtr> outputs_addr_;
std::vector<KernelTensorPtr> inputs_;
std::vector<KernelTensorPtr> workspace_;
std::vector<KernelTensorPtr> outputs_;
};
using KernelModPtr = std::shared_ptr<KernelMod>;


@@ -89,25 +89,6 @@ std::vector<KernelAttr> NativeCpuKernelMod::GetSupportFromOpLib(const std::strin
return support_kernel_attrs;
}
bool NativeCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
MS_EXCEPTION_IF_CHECK_FAIL(this->inputs_.size() == inputs.size(), "inputs size check failed");
MS_EXCEPTION_IF_CHECK_FAIL(this->outputs_.size() == outputs.size(), "outputs size check failed");
auto it1 = this->inputs_.begin();
auto it2 = inputs.begin();
for (; it1 != this->inputs_.end() && it2 != inputs.end(); it1++, it2++) {
(*it1)->SetData((*it2));
}
it1 = this->outputs_.begin();
it2 = outputs.begin();
for (; it1 != this->outputs_.end() && it2 != outputs.end(); it1++, it2++) {
(*it1)->SetData((*it2));
}
return Launch(this->inputs_, this->outputs_, workspace);
}
int DeprecatedNativeCpuKernelMod::Resize(const BaseOperatorPtr &, const std::vector<KernelTensorPtr> &,
const std::vector<KernelTensorPtr> &,
const std::map<uint32_t, tensor::TensorPtr> &) {
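
For reference, the pattern the deleted adapter implemented, as a standalone sketch (illustrative; BindBuffers is a hypothetical helper, not MindSpore code):

// Illustrative only: attach raw address buffers to cached kernel tensors,
// mirroring what the removed NativeCpuKernelMod::Launch adapter did before
// forwarding to the KernelTensorPtr-based Launch.
void BindBuffers(const std::vector<KernelTensorPtr> &tensors,
                 const std::vector<AddressPtr> &buffers) {
  MS_EXCEPTION_IF_CHECK_FAIL(tensors.size() == buffers.size(), "size check failed");
  for (size_t i = 0; i < tensors.size(); ++i) {
    tensors[i]->SetData(buffers[i]);
  }
}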


@@ -143,8 +143,14 @@ class BACKEND_EXPORT NativeCpuKernelMod : public CpuKernelMod {
const std::vector<AddressPtr> &outputs, void * /*stream_ptr*/) override {
return Launch(inputs, workspace, outputs);
}
bool Launch(const std::vector<KernelTensorPtr> &inputs, const std::vector<KernelTensorPtr> &outputs,
const std::vector<AddressPtr> &workspace, void *) override {
return Launch(inputs, outputs, workspace);
}
virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
const std::vector<AddressPtr> &outputs) {
return false;
}
virtual bool Launch(const std::vector<KernelTensorPtr> &inputs, const std::vector<KernelTensorPtr> &outputs,
const std::vector<AddressPtr> &workspace) {
return false;
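
A hedged sketch of how a concrete CPU kernel typically plugs into this overload set (illustrative; the AddCpuKernelMod name, the float32 add body, and the addr/size fields of Address are assumptions, not part of this diff):

// Illustrative only: a concrete kernel overrides the AddressPtr-based Launch
// declared above; the framework-facing overloads forward into it.
class AddCpuKernelMod : public NativeCpuKernelMod {  // hypothetical kernel
 public:
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs) override {
    const auto *a = static_cast<const float *>(inputs[0]->addr);
    const auto *b = static_cast<const float *>(inputs[1]->addr);
    auto *out = static_cast<float *>(outputs[0]->addr);
    size_t n = outputs[0]->size / sizeof(float);
    for (size_t i = 0; i < n; ++i) {
      out[i] = a[i] + b[i];
    }
    return true;
  }
};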