clean code
This commit is contained in:
parent
66761f2808
commit
728ffa6370
|
@ -628,7 +628,7 @@ void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr>
|
|||
GetValidKernelNodes(func_graph, node_list);
|
||||
|
||||
auto parameters = func_graph->parameters();
|
||||
input_list->insert(input_list->cbegin(), parameters.begin(), parameters.end());
|
||||
(void)input_list->insert(input_list->cbegin(), parameters.begin(), parameters.end());
|
||||
|
||||
GetFuncGraphOutputNodes(func_graph, output_list);
|
||||
}
|
||||
|
@ -1353,22 +1353,10 @@ void SyncOutInRef(const KernelAttr &from_kernel_attr, KernelAttr *to_kernel_attr
|
|||
for (const auto &ref : out_in_ref) {
|
||||
(void)to_kernel_attr->AddOutInRef(ref.first, ref.second);
|
||||
}
|
||||
to_kernel_attr->AddAllOutInRef(all_out_in_ref);
|
||||
(void)to_kernel_attr->AddAllOutInRef(all_out_in_ref);
|
||||
}
|
||||
|
||||
namespace broadcast_utils {
|
||||
bool IsBroadcast(const std::vector<size_t> &lhs, const std::vector<size_t> &rhs) {
|
||||
if (lhs.size() != rhs.size()) {
|
||||
return true;
|
||||
}
|
||||
for (size_t i = 0; i < lhs.size(); i++) {
|
||||
if (lhs[i] != rhs[i]) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool AlignedBroadCastShape(size_t align_rank, std::vector<size_t> *broadcast, std::vector<size_t> *lhs,
|
||||
std::vector<size_t> *rhs) {
|
||||
if (broadcast == nullptr || lhs == nullptr || rhs == nullptr) {
|
||||
|
|
|
@ -364,7 +364,6 @@ class MatchKernelHelper {
|
|||
};
|
||||
|
||||
namespace broadcast_utils {
|
||||
bool IsBroadcast(const std::vector<size_t> &lhs, const std::vector<size_t> &rhs);
|
||||
bool AlignedBroadCastShape(size_t align_rank, std::vector<size_t> *broadcast, std::vector<size_t> *lhs,
|
||||
std::vector<size_t> *rhs);
|
||||
} // namespace broadcast_utils
|
||||
|
|
|
@ -114,6 +114,8 @@ class OpInfo {
|
|||
dynamic_compile_static_ = opinfo.dynamic_compile_static_;
|
||||
op_pattern_ = opinfo.op_pattern();
|
||||
processor_ = opinfo.processor_;
|
||||
input_to_attr_index_ = opinfo.input_to_attr_index_;
|
||||
real_input_index_ = opinfo.real_input_index_;
|
||||
need_check_supported_ = opinfo.need_check_supported();
|
||||
is_dynamic_format_ = opinfo.is_dynamic_format();
|
||||
for (const auto &attr : opinfo.attrs_ptr()) {
|
||||
|
|
|
@ -150,8 +150,8 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p
|
|||
std::map<size_t, size_t> real_index;
|
||||
std::map<size_t, size_t> ori_index;
|
||||
for (size_t i = 0; i < real_input_index.size(); ++i) {
|
||||
real_index.emplace(std::pair{i, real_input_index.at(i)});
|
||||
ori_index.emplace(std::pair{real_input_index.at(i), i});
|
||||
(void)real_index.emplace(std::pair{i, real_input_index.at(i)});
|
||||
(void)ori_index.emplace(std::pair{real_input_index.at(i), i});
|
||||
}
|
||||
op_info->set_real_input_index(std::pair{real_index, ori_index});
|
||||
}
|
||||
|
|
|
@ -27,7 +27,7 @@ void CheckRtRetWithError(rtError_t error, const std::string &msg) {
|
|||
}
|
||||
|
||||
AscendDataQueueDynamic::AscendDataQueueDynamic(const size_t capacity)
|
||||
: DataQueue(capacity), stream_(0), node_info_(nullptr) {
|
||||
: DataQueue(capacity), stream_(nullptr), node_info_(nullptr) {
|
||||
auto context_key = device_context_->device_context_key();
|
||||
auto runtime_instance = dynamic_cast<ascend::AscendKernelRuntime *>(
|
||||
device::KernelRuntimeManager::Instance().GetKernelRuntime(context_key.device_name_, context_key.device_id_));
|
||||
|
|
|
@ -125,7 +125,7 @@ int BroadcastOpGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
|
|||
if (is_null_input_) {
|
||||
return KRET_OK;
|
||||
}
|
||||
need_broadcast_ = broadcast_utils::IsBroadcast(lhs_shape_, rhs_shape_);
|
||||
need_broadcast_ = common::AnfAlgo::IsTensorBroadcast(lhs_shape_, rhs_shape_);
|
||||
if (!broadcast_utils::AlignedBroadCastShape(MAX_DIMS, &output_shape_, &lhs_shape_, &rhs_shape_)) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it's dimension of input cannot be greater than " << MAX_DIMS
|
||||
<< ", but got " << lhs_shape_.size();
|
||||
|
|
|
@ -82,7 +82,7 @@ int BroadcastOpGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
|
|||
if (is_null_input_) {
|
||||
return KRET_OK;
|
||||
}
|
||||
need_broadcast_ = broadcast_utils::IsBroadcast(x1_shape_, x2_shape_);
|
||||
need_broadcast_ = common::AnfAlgo::IsTensorBroadcast(x1_shape_, x2_shape_);
|
||||
// For x1_shape, x2_shape, dy_shape, it's validation has been done in core/ops/xxx.cc.
|
||||
// But we need check shape rank less equal to 7D.
|
||||
if (!broadcast_utils::AlignedBroadCastShape(kMaxDim, &dy_shape_, &x1_shape_, &x2_shape_)) {
|
||||
|
|
|
@ -86,7 +86,7 @@ int BroadcastOpGradGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator
|
|||
if (is_null_input_) {
|
||||
return KRET_OK;
|
||||
}
|
||||
need_broadcast_ = broadcast_utils::IsBroadcast(x1_shape_, x2_shape_);
|
||||
need_broadcast_ = common::AnfAlgo::IsTensorBroadcast(x1_shape_, x2_shape_);
|
||||
// For x1_shape, x2_shape, dy1_shape, it's validation has been done in core/ops/xxx.cc.
|
||||
// But we need check shape rank less equal to 7D.
|
||||
if (!broadcast_utils::AlignedBroadCastShape(kMaxDim, &sopd_grad_shape_, &x1_shape_, &x2_shape_)) {
|
||||
|
|
Loading…
Reference in New Issue