!47865 fix MaxPool3DGradGrad bug
Merge pull request !47865 from lyqlola/clean
This commit is contained in:
commit
d1c45cd2f1
|
@ -243,7 +243,9 @@
|
|||
"axes": "axis"
|
||||
},
|
||||
"MaxPool3DGradGradD": {
|
||||
"ksize": "kernel_size"
|
||||
"ksize": "kernel_size",
|
||||
"pads": "pad_list",
|
||||
"data_format": "format"
|
||||
}
|
||||
},
|
||||
"AttrDefaultValue": {
|
||||
|
|
|
@ -54,8 +54,8 @@ class BACKEND_EXPORT DataQueue {
|
|||
virtual DataQueueStatus Front(std::vector<DataQueueItem> *data) const = 0;
|
||||
virtual DataQueueStatus Pop() = 0;
|
||||
virtual void SetThreadDevice() {}
|
||||
virtual size_t Size() { return size_; }
|
||||
virtual size_t Capacity() { return capacity_; }
|
||||
virtual size_t Size() const { return size_; }
|
||||
virtual size_t Capacity() const { return capacity_; }
|
||||
|
||||
protected:
|
||||
const std::string channel_name_;
|
||||
|
|
|
@ -488,6 +488,7 @@ constexpr auto kMaximumGradGradOpName = "MaximumGradGrad";
|
|||
constexpr auto kMaximumGradOpName = "MaximumGrad";
|
||||
constexpr auto kMaximumOpName = "Maximum";
|
||||
constexpr auto kMaxPool3DGradGradOpName = "MaxPool3DGradGrad";
|
||||
constexpr auto kMaxPool3DGradGradDOpName = "MaxPool3DGradGradD";
|
||||
constexpr auto kMaxPool3DGradOpName = "MaxPool3DGrad";
|
||||
constexpr auto kMaxPool3DOpName = "MaxPool3D";
|
||||
constexpr auto kMaxPoolGradOpName = "MaxPoolGrad";
|
||||
|
|
|
@ -68,7 +68,7 @@ class WingmanQueue : public DataQueue {
|
|||
DataQueueStatus Pop() override;
|
||||
bool IsEmpty() const override { return queue_.empty(); }
|
||||
bool IsFull() const override { return false; }
|
||||
size_t Size() override { return queue_.size(); }
|
||||
size_t Size() const override { return queue_.size(); }
|
||||
|
||||
private:
|
||||
std::queue<std::vector<DataQueueItem>> queue_;
|
||||
|
|
|
@ -98,7 +98,7 @@ ValueNodePtr CreateValueNode(const AnfNodePtr &node) {
|
|||
|
||||
const BaseRef MaxPool3DGradGradFission::DefinePattern() const {
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
auto max_pool3d_grad_grad_prim = std::make_shared<Primitive>(kMaxPool3DGradGradOpName);
|
||||
auto max_pool3d_grad_grad_prim = std::make_shared<Primitive>(kMaxPool3DGradGradDOpName);
|
||||
return VectorRef({max_pool3d_grad_grad_prim, Xs});
|
||||
}
|
||||
|
||||
|
@ -113,7 +113,7 @@ const AnfNodePtr MaxPool3DGradGradFission::Process(const FuncGraphPtr &graph, co
|
|||
MS_LOG(INFO) << "The node " << cnode->DebugString() << " is not equal to " << kInputNum << " inputs";
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(kMaxPool3DGradGradOpName))};
|
||||
std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(kMaxPool3DGradGradDOpName))};
|
||||
auto assist_const = CreateValueNode(cnode);
|
||||
(void)new_inputs.insert(new_inputs.cend(), cnode->inputs().cbegin() + 1, cnode->inputs().cend());
|
||||
(void)new_inputs.emplace_back(assist_const);
|
||||
|
|
|
@ -89,6 +89,7 @@ REG_ASCEND_VM_OP_ADAPTATION_INFO(kLogSoftmaxOpName).set_backend_op_name(kLogSoft
|
|||
REG_ASCEND_VM_OP_ADAPTATION_INFO(kMatrixDiagOpName).set_backend_op_name(kMatrixDiagDOpName);
|
||||
REG_ASCEND_VM_OP_ADAPTATION_INFO(kMatrixDiagPartOpName).set_backend_op_name(kMatrixDiagPartDOpName);
|
||||
REG_ASCEND_VM_OP_ADAPTATION_INFO(kMatrixSetDiagOpName).set_backend_op_name(kMatrixSetDiagDOpName);
|
||||
REG_ASCEND_VM_OP_ADAPTATION_INFO(kMaxPool3DGradGradOpName).set_backend_op_name(kMaxPool3DGradGradDOpName);
|
||||
REG_ASCEND_VM_OP_ADAPTATION_INFO(kIm2ColOpName).set_backend_op_name(kIm2colOpName);
|
||||
REG_ASCEND_VM_OP_ADAPTATION_INFO(kNewIm2ColOpName).set_backend_op_name(kIm2colOpName);
|
||||
REG_ASCEND_VM_OP_ADAPTATION_INFO(kParallelResizeBilinearOpName).set_backend_op_name(kSyncResizeBilinearV2OpName);
|
||||
|
|
Loading…
Reference in New Issue