From 65842f749bbb09e779e4bc8d68ab884bacc86e63 Mon Sep 17 00:00:00 2001 From: Jonathan Rudenberg Date: Sun, 14 Sep 2014 20:29:27 -0400 Subject: [PATCH] netlink: Extract message checks into reusable method Signed-off-by: Jonathan Rudenberg --- netlink/netlink_linux.go | 62 ++++++++++++++++++++-------------------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/netlink/netlink_linux.go b/netlink/netlink_linux.go index e3579eea..74980d72 100644 --- a/netlink/netlink_linux.go +++ b/netlink/netlink_linux.go @@ -3,6 +3,7 @@ package netlink import ( "encoding/binary" "fmt" + "io" "net" "sync/atomic" "syscall" @@ -322,35 +323,44 @@ func (s *NetlinkSocket) GetPid() (uint32, error) { return 0, ErrWrongSockType } -func (s *NetlinkSocket) HandleAck(seq uint32) error { +func (s *NetlinkSocket) CheckMessage(m syscall.NetlinkMessage, seq, pid uint32) error { + if m.Header.Seq != seq { + return fmt.Errorf("netlink: invalid seq %d, expected %d", m.Header.Seq, seq) + } + if m.Header.Pid != pid { + return fmt.Errorf("netlink: wrong pid %d, expected %d", m.Header.Pid, pid) + } + if m.Header.Type == syscall.NLMSG_DONE { + return io.EOF + } + if m.Header.Type == syscall.NLMSG_ERROR { + e := int32(native.Uint32(m.Data[0:4])) + if e == 0 { + return io.EOF + } + return syscall.Errno(-e) + } + return nil +} +func (s *NetlinkSocket) HandleAck(seq uint32) error { pid, err := s.GetPid() if err != nil { return err } -done: +outer: for { msgs, err := s.Receive() if err != nil { return err } for _, m := range msgs { - if m.Header.Seq != seq { - return fmt.Errorf("Wrong Seq nr %d, expected %d", m.Header.Seq, seq) - } - if m.Header.Pid != pid { - return fmt.Errorf("Wrong pid %d, expected %d", m.Header.Pid, pid) - } - if m.Header.Type == syscall.NLMSG_DONE { - break done - } - if m.Header.Type == syscall.NLMSG_ERROR { - error := int32(native.Uint32(m.Data[0:4])) - if error == 0 { - break done + if err := s.CheckMessage(m, seq, pid); err != nil { + if err == io.EOF { + break outer } - return syscall.Errno(-error) + return err } } } @@ -781,28 +791,18 @@ func NetworkGetRoutes() ([]Route, error) { res := make([]Route, 0) -done: +outer: for { msgs, err := s.Receive() if err != nil { return nil, err } for _, m := range msgs { - if m.Header.Seq != wb.Seq { - return nil, fmt.Errorf("Wrong Seq nr %d, expected 1", m.Header.Seq) - } - if m.Header.Pid != pid { - return nil, fmt.Errorf("Wrong pid %d, expected %d", m.Header.Pid, pid) - } - if m.Header.Type == syscall.NLMSG_DONE { - break done - } - if m.Header.Type == syscall.NLMSG_ERROR { - error := int32(native.Uint32(m.Data[0:4])) - if error == 0 { - break done + if err := s.CheckMessage(m, wb.Seq, pid); err != nil { + if err == io.EOF { + break outer } - return nil, syscall.Errno(-error) + return nil, err } if m.Header.Type != syscall.RTM_NEWROUTE { continue