Update code to use pault.ag/go/sniff package (#5038)

* Update code to use pault.ag/go/sniff package

* Update go dependencies
This commit is contained in:
Manuel Alejandro de Brito Fontes 2020-02-07 12:27:43 -03:00 committed by GitHub
parent 3e2bbbed3d
commit d0423c6d4f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 213 additions and 59 deletions

View file

@ -49,6 +49,7 @@ type Listener struct {
Listener net.Listener
ProxyHeaderTimeout time.Duration
SourceCheck SourceChecker
UnknownOK bool // allow PROXY UNKNOWN
}
// Conn is used to wrap and underlying connection which
@ -62,6 +63,7 @@ type Conn struct {
useConnAddr bool
once sync.Once
proxyHeaderTimeout time.Duration
unknownOK bool
}
// Accept waits for and returns the next connection to the listener.
@ -83,6 +85,7 @@ func (p *Listener) Accept() (net.Conn, error) {
}
newConn := NewConn(conn, p.ProxyHeaderTimeout)
newConn.useConnAddr = useConnAddr
newConn.unknownOK = p.UnknownOK
return newConn, nil
}
@ -119,6 +122,22 @@ func (p *Conn) Read(b []byte) (int, error) {
return p.bufReader.Read(b)
}
func (p *Conn) ReadFrom(r io.Reader) (int64, error) {
if rf, ok := p.conn.(io.ReaderFrom); ok {
return rf.ReadFrom(r)
}
return io.Copy(p.conn, r)
}
func (p *Conn) WriteTo(w io.Writer) (int64, error) {
var err error
p.once.Do(func() { err = p.checkPrefix() })
if err != nil {
return 0, err
}
return p.bufReader.WriteTo(w)
}
func (p *Conn) Write(b []byte) (int, error) {
return p.conn.Write(b)
}
@ -209,13 +228,20 @@ func (p *Conn) checkPrefix() error {
// Split on spaces, should be (PROXY <type> <src addr> <dst addr> <src port> <dst port>)
parts := strings.Split(header, " ")
if len(parts) != 6 {
if len(parts) < 2 {
p.conn.Close()
return fmt.Errorf("Invalid header line: %s", header)
}
// Verify the type is known
switch parts[1] {
case "UNKNOWN":
if !p.unknownOK || len(parts) != 2 {
p.conn.Close()
return fmt.Errorf("Invalid UNKNOWN header line: %s", header)
}
p.useConnAddr = true
return nil
case "TCP4":
case "TCP6":
default:
@ -223,6 +249,11 @@ func (p *Conn) checkPrefix() error {
return fmt.Errorf("Unhandled address type: %s", parts[1])
}
if len(parts) != 6 {
p.conn.Close()
return fmt.Errorf("Invalid header line: %s", header)
}
// Parse out the source address
ip := net.ParseIP(parts[2])
if ip == nil {