fix MaxPool3DGradGrad bug

This commit is contained in:
liyiqi 2023-01-13 17:05:10 +08:00
parent 66b258841e
commit 432ae32409
6 changed files with 10 additions and 6 deletions

View File

@ -240,7 +240,9 @@
"axes": "axis"
},
"MaxPool3DGradGradD": {
"ksize": "kernel_size"
"ksize": "kernel_size",
"pads": "pad_list",
"data_format": "format"
}
},
"AttrDefaultValue": {

View File

@ -54,8 +54,8 @@ class BACKEND_EXPORT DataQueue {
virtual DataQueueStatus Front(std::vector<DataQueueItem> *data) const = 0;
virtual DataQueueStatus Pop() = 0;
virtual void SetThreadDevice() {}
virtual size_t Size() { return size_; }
virtual size_t Capacity() { return capacity_; }
virtual size_t Size() const { return size_; }
virtual size_t Capacity() const { return capacity_; }
protected:
const std::string channel_name_;

View File

@ -429,6 +429,7 @@ constexpr auto kMaximumGradGradOpName = "MaximumGradGrad";
constexpr auto kMaximumGradOpName = "MaximumGrad";
constexpr auto kMaximumOpName = "Maximum";
constexpr auto kMaxPool3DGradGradOpName = "MaxPool3DGradGrad";
constexpr auto kMaxPool3DGradGradDOpName = "MaxPool3DGradGradD";
constexpr auto kMaxPool3DGradOpName = "MaxPool3DGrad";
constexpr auto kMaxPool3DOpName = "MaxPool3D";
constexpr auto kMaxPoolGradOpName = "MaxPoolGrad";

View File

@ -68,7 +68,7 @@ class WingmanQueue : public DataQueue {
DataQueueStatus Pop() override;
bool IsEmpty() const override { return queue_.empty(); }
bool IsFull() const override { return false; }
size_t Size() override { return queue_.size(); }
size_t Size() const override { return queue_.size(); }
private:
std::queue<std::vector<DataQueueItem>> queue_;

View File

@ -98,7 +98,7 @@ ValueNodePtr CreateValueNode(const AnfNodePtr &node) {
const BaseRef MaxPool3DGradGradFission::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
auto max_pool3d_grad_grad_prim = std::make_shared<Primitive>(kMaxPool3DGradGradOpName);
auto max_pool3d_grad_grad_prim = std::make_shared<Primitive>(kMaxPool3DGradGradDOpName);
return VectorRef({max_pool3d_grad_grad_prim, Xs});
}
@ -113,7 +113,7 @@ const AnfNodePtr MaxPool3DGradGradFission::Process(const FuncGraphPtr &graph, co
MS_LOG(INFO) << "The node " << cnode->DebugString() << " is not equal to " << kInputNum << " inputs";
return nullptr;
}
std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(kMaxPool3DGradGradOpName))};
std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(kMaxPool3DGradGradDOpName))};
auto assist_const = CreateValueNode(cnode);
(void)new_inputs.insert(new_inputs.cend(), cnode->inputs().cbegin() + 1, cnode->inputs().cend());
(void)new_inputs.emplace_back(assist_const);

View File

@ -89,6 +89,7 @@ REG_ASCEND_VM_OP_ADAPTATION_INFO(kLogSoftmaxOpName).set_backend_op_name(kLogSoft
REG_ASCEND_VM_OP_ADAPTATION_INFO(kMatrixDiagOpName).set_backend_op_name(kMatrixDiagDOpName);
REG_ASCEND_VM_OP_ADAPTATION_INFO(kMatrixDiagPartOpName).set_backend_op_name(kMatrixDiagPartDOpName);
REG_ASCEND_VM_OP_ADAPTATION_INFO(kMatrixSetDiagOpName).set_backend_op_name(kMatrixSetDiagDOpName);
REG_ASCEND_VM_OP_ADAPTATION_INFO(kMaxPool3DGradGradOpName).set_backend_op_name(kMaxPool3DGradGradDOpName);
REG_ASCEND_VM_OP_ADAPTATION_INFO(kIm2ColOpName).set_backend_op_name(kIm2colOpName);
REG_ASCEND_VM_OP_ADAPTATION_INFO(kNewIm2ColOpName).set_backend_op_name(kIm2colOpName);
REG_ASCEND_VM_OP_ADAPTATION_INFO(kParallelResizeBilinearOpName).set_backend_op_name(kSyncResizeBilinearV2OpName);