upgrade version of lib/pq to v1.1.0 (#6640)
Adds SCRAM-SHA-256 authentication
This commit is contained in:
		
							parent
							
								
									83d6e5e3f8
								
							
						
					
					
						commit
						3fb038c53a
					
				
							
								
								
									
										2
									
								
								go.mod
								
								
								
								
							
							
						
						
									
										2
									
								
								go.mod
								
								
								
								
							| 
						 | 
				
			
			@ -74,7 +74,7 @@ require (
 | 
			
		|||
	github.com/klauspost/cpuid v0.0.0-20160302075316-09cded8978dc // indirect
 | 
			
		||||
	github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6 // indirect
 | 
			
		||||
	github.com/lafriks/xormstore v1.0.0
 | 
			
		||||
	github.com/lib/pq v1.0.0
 | 
			
		||||
	github.com/lib/pq v1.1.0
 | 
			
		||||
	github.com/lunny/dingtalk_webhook v0.0.0-20171025031554-e3534c89ef96
 | 
			
		||||
	github.com/lunny/levelqueue v0.0.0-20190217115915-02b525a4418e
 | 
			
		||||
	github.com/lunny/log v0.0.0-20160921050905-7887c61bf0de // indirect
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										2
									
								
								go.sum
								
								
								
								
							
							
						
						
									
										2
									
								
								go.sum
								
								
								
								
							| 
						 | 
				
			
			@ -198,6 +198,8 @@ github.com/lafriks/xormstore v1.0.0 h1:P/IJzNSIpjXl/Up3o2Td5ZU/x4v6DEKLMaPQJGtmJ
 | 
			
		|||
github.com/lafriks/xormstore v1.0.0/go.mod h1:dD8vHNRfEp3Uy+JvX9cMi2SXcRKJ0x4pYKsZuy843Ic=
 | 
			
		||||
github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A=
 | 
			
		||||
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
 | 
			
		||||
github.com/lib/pq v1.1.0 h1:/5u4a+KGJptBRqGzPvYQL9p0d/tPR4S31+Tnzj9lEO4=
 | 
			
		||||
github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
 | 
			
		||||
github.com/lunny/dingtalk_webhook v0.0.0-20171025031554-e3534c89ef96 h1:uNwtsDp7ci48vBTTxDuwcoTXz4lwtDTe7TjCQ0noaWY=
 | 
			
		||||
github.com/lunny/dingtalk_webhook v0.0.0-20171025031554-e3534c89ef96/go.mod h1:mmIfjCSQlGYXmJ95jFN84AkQFnVABtKuJL8IrzwvUKQ=
 | 
			
		||||
github.com/lunny/levelqueue v0.0.0-20190217115915-02b525a4418e h1:GSprKUrG9wNgwQgROvjPGXmcZrg4OLslOuZGB0uJjx8=
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -80,7 +80,7 @@ megacheck_install() {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
golint_install() {
 | 
			
		||||
	go get github.com/golang/lint/golint
 | 
			
		||||
	go get golang.org/x/lint/golint
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
$1
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,9 +1,9 @@
 | 
			
		|||
language: go
 | 
			
		||||
 | 
			
		||||
go:
 | 
			
		||||
  - 1.8.x
 | 
			
		||||
  - 1.9.x
 | 
			
		||||
  - 1.10.x
 | 
			
		||||
  - 1.11.x
 | 
			
		||||
  - master
 | 
			
		||||
 | 
			
		||||
sudo: true
 | 
			
		||||
| 
						 | 
				
			
			@ -44,7 +44,7 @@ script:
 | 
			
		|||
  - >
 | 
			
		||||
    goimports -d -e $(find -name '*.go') | awk '{ print } END { exit NR == 0 ? 0 : 1 }'
 | 
			
		||||
  - go vet ./...
 | 
			
		||||
  - megacheck -go 1.8 ./...
 | 
			
		||||
  - megacheck -go 1.9 ./...
 | 
			
		||||
  - golint ./...
 | 
			
		||||
  - PQTEST_BINARY_PARAMETERS=no  go test -race -v ./...
 | 
			
		||||
  - PQTEST_BINARY_PARAMETERS=yes go test -race -v ./...
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -10,7 +10,7 @@
 | 
			
		|||
## Docs
 | 
			
		||||
 | 
			
		||||
For detailed documentation and basic usage examples, please see the package
 | 
			
		||||
documentation at <http://godoc.org/github.com/lib/pq>.
 | 
			
		||||
documentation at <https://godoc.org/github.com/lib/pq>.
 | 
			
		||||
 | 
			
		||||
## Tests
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -2,7 +2,9 @@ package pq
 | 
			
		|||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"context"
 | 
			
		||||
	"crypto/md5"
 | 
			
		||||
	"crypto/sha256"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
	"encoding/binary"
 | 
			
		||||
| 
						 | 
				
			
			@ -20,6 +22,7 @@ import (
 | 
			
		|||
	"unicode"
 | 
			
		||||
 | 
			
		||||
	"github.com/lib/pq/oid"
 | 
			
		||||
	"github.com/lib/pq/scram"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Common error types
 | 
			
		||||
| 
						 | 
				
			
			@ -89,13 +92,24 @@ type Dialer interface {
 | 
			
		|||
	DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type defaultDialer struct{}
 | 
			
		||||
 | 
			
		||||
func (d defaultDialer) Dial(ntw, addr string) (net.Conn, error) {
 | 
			
		||||
	return net.Dial(ntw, addr)
 | 
			
		||||
type DialerContext interface {
 | 
			
		||||
	DialContext(ctx context.Context, network, address string) (net.Conn, error)
 | 
			
		||||
}
 | 
			
		||||
func (d defaultDialer) DialTimeout(ntw, addr string, timeout time.Duration) (net.Conn, error) {
 | 
			
		||||
	return net.DialTimeout(ntw, addr, timeout)
 | 
			
		||||
 | 
			
		||||
type defaultDialer struct {
 | 
			
		||||
	d net.Dialer
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d defaultDialer) Dial(network, address string) (net.Conn, error) {
 | 
			
		||||
	return d.d.Dial(network, address)
 | 
			
		||||
}
 | 
			
		||||
func (d defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
 | 
			
		||||
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
 | 
			
		||||
	defer cancel()
 | 
			
		||||
	return d.DialContext(ctx, network, address)
 | 
			
		||||
}
 | 
			
		||||
func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
 | 
			
		||||
	return d.d.DialContext(ctx, network, address)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type conn struct {
 | 
			
		||||
| 
						 | 
				
			
			@ -244,90 +258,35 @@ func (cn *conn) writeBuf(b byte) *writeBuf {
 | 
			
		|||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Open opens a new connection to the database. name is a connection string.
 | 
			
		||||
// Open opens a new connection to the database. dsn is a connection string.
 | 
			
		||||
// Most users should only use it through database/sql package from the standard
 | 
			
		||||
// library.
 | 
			
		||||
func Open(name string) (_ driver.Conn, err error) {
 | 
			
		||||
	return DialOpen(defaultDialer{}, name)
 | 
			
		||||
func Open(dsn string) (_ driver.Conn, err error) {
 | 
			
		||||
	return DialOpen(defaultDialer{}, dsn)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DialOpen opens a new connection to the database using a dialer.
 | 
			
		||||
func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
 | 
			
		||||
func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) {
 | 
			
		||||
	c, err := NewConnector(dsn)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	c.dialer = d
 | 
			
		||||
	return c.open(context.Background())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Connector) open(ctx context.Context) (cn *conn, err error) {
 | 
			
		||||
	// Handle any panics during connection initialization.  Note that we
 | 
			
		||||
	// specifically do *not* want to use errRecover(), as that would turn any
 | 
			
		||||
	// connection errors into ErrBadConns, hiding the real error message from
 | 
			
		||||
	// the user.
 | 
			
		||||
	defer errRecoverNoErrBadConn(&err)
 | 
			
		||||
 | 
			
		||||
	o := make(values)
 | 
			
		||||
	o := c.opts
 | 
			
		||||
 | 
			
		||||
	// A number of defaults are applied here, in this order:
 | 
			
		||||
	//
 | 
			
		||||
	// * Very low precedence defaults applied in every situation
 | 
			
		||||
	// * Environment variables
 | 
			
		||||
	// * Explicitly passed connection information
 | 
			
		||||
	o["host"] = "localhost"
 | 
			
		||||
	o["port"] = "5432"
 | 
			
		||||
	// N.B.: Extra float digits should be set to 3, but that breaks
 | 
			
		||||
	// Postgres 8.4 and older, where the max is 2.
 | 
			
		||||
	o["extra_float_digits"] = "2"
 | 
			
		||||
	for k, v := range parseEnviron(os.Environ()) {
 | 
			
		||||
		o[k] = v
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") {
 | 
			
		||||
		name, err = ParseURL(name)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := parseOpts(name, o); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Use the "fallback" application name if necessary
 | 
			
		||||
	if fallback, ok := o["fallback_application_name"]; ok {
 | 
			
		||||
		if _, ok := o["application_name"]; !ok {
 | 
			
		||||
			o["application_name"] = fallback
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// We can't work with any client_encoding other than UTF-8 currently.
 | 
			
		||||
	// However, we have historically allowed the user to set it to UTF-8
 | 
			
		||||
	// explicitly, and there's no reason to break such programs, so allow that.
 | 
			
		||||
	// Note that the "options" setting could also set client_encoding, but
 | 
			
		||||
	// parsing its value is not worth it.  Instead, we always explicitly send
 | 
			
		||||
	// client_encoding as a separate run-time parameter, which should override
 | 
			
		||||
	// anything set in options.
 | 
			
		||||
	if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) {
 | 
			
		||||
		return nil, errors.New("client_encoding must be absent or 'UTF8'")
 | 
			
		||||
	}
 | 
			
		||||
	o["client_encoding"] = "UTF8"
 | 
			
		||||
	// DateStyle needs a similar treatment.
 | 
			
		||||
	if datestyle, ok := o["datestyle"]; ok {
 | 
			
		||||
		if datestyle != "ISO, MDY" {
 | 
			
		||||
			panic(fmt.Sprintf("setting datestyle must be absent or %v; got %v",
 | 
			
		||||
				"ISO, MDY", datestyle))
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		o["datestyle"] = "ISO, MDY"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// If a user is not provided by any other means, the last
 | 
			
		||||
	// resort is to use the current operating system provided user
 | 
			
		||||
	// name.
 | 
			
		||||
	if _, ok := o["user"]; !ok {
 | 
			
		||||
		u, err := userCurrent()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		o["user"] = u
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	cn := &conn{
 | 
			
		||||
	cn = &conn{
 | 
			
		||||
		opts:   o,
 | 
			
		||||
		dialer: d,
 | 
			
		||||
		dialer: c.dialer,
 | 
			
		||||
	}
 | 
			
		||||
	err = cn.handleDriverSettings(o)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
| 
						 | 
				
			
			@ -335,7 +294,7 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
 | 
			
		|||
	}
 | 
			
		||||
	cn.handlePgpass(o)
 | 
			
		||||
 | 
			
		||||
	cn.c, err = dial(d, o)
 | 
			
		||||
	cn.c, err = dial(ctx, c.dialer, o)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -364,10 +323,10 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
 | 
			
		|||
	return cn, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func dial(d Dialer, o values) (net.Conn, error) {
 | 
			
		||||
	ntw, addr := network(o)
 | 
			
		||||
func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) {
 | 
			
		||||
	network, address := network(o)
 | 
			
		||||
	// SSL is not necessary or supported over UNIX domain sockets
 | 
			
		||||
	if ntw == "unix" {
 | 
			
		||||
	if network == "unix" {
 | 
			
		||||
		o["sslmode"] = "disable"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -378,19 +337,30 @@ func dial(d Dialer, o values) (net.Conn, error) {
 | 
			
		|||
			return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err)
 | 
			
		||||
		}
 | 
			
		||||
		duration := time.Duration(seconds) * time.Second
 | 
			
		||||
 | 
			
		||||
		// connect_timeout should apply to the entire connection establishment
 | 
			
		||||
		// procedure, so we both use a timeout for the TCP connection
 | 
			
		||||
		// establishment and set a deadline for doing the initial handshake.
 | 
			
		||||
		// The deadline is then reset after startup() is done.
 | 
			
		||||
		deadline := time.Now().Add(duration)
 | 
			
		||||
		conn, err := d.DialTimeout(ntw, addr, duration)
 | 
			
		||||
		var conn net.Conn
 | 
			
		||||
		if dctx, ok := d.(DialerContext); ok {
 | 
			
		||||
			ctx, cancel := context.WithTimeout(ctx, duration)
 | 
			
		||||
			defer cancel()
 | 
			
		||||
			conn, err = dctx.DialContext(ctx, network, address)
 | 
			
		||||
		} else {
 | 
			
		||||
			conn, err = d.DialTimeout(network, address, duration)
 | 
			
		||||
		}
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		err = conn.SetDeadline(deadline)
 | 
			
		||||
		return conn, err
 | 
			
		||||
	}
 | 
			
		||||
	return d.Dial(ntw, addr)
 | 
			
		||||
	if dctx, ok := d.(DialerContext); ok {
 | 
			
		||||
		return dctx.DialContext(ctx, network, address)
 | 
			
		||||
	}
 | 
			
		||||
	return d.Dial(network, address)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func network(o values) (string, string) {
 | 
			
		||||
| 
						 | 
				
			
			@ -704,7 +674,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) {
 | 
			
		|||
			// res might be non-nil here if we received a previous
 | 
			
		||||
			// CommandComplete, but that's fine; just overwrite it
 | 
			
		||||
			res = &rows{cn: cn}
 | 
			
		||||
			res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r)
 | 
			
		||||
			res.rowsHeader = parsePortalRowDescribe(r)
 | 
			
		||||
 | 
			
		||||
			// To work around a bug in QueryRow in Go 1.2 and earlier, wait
 | 
			
		||||
			// until the first DataRow has been received.
 | 
			
		||||
| 
						 | 
				
			
			@ -861,7 +831,7 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
 | 
			
		|||
		cn.readParseResponse()
 | 
			
		||||
		cn.readBindResponse()
 | 
			
		||||
		rows := &rows{cn: cn}
 | 
			
		||||
		rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse()
 | 
			
		||||
		rows.rowsHeader = cn.readPortalDescribeResponse()
 | 
			
		||||
		cn.postExecuteWorkaround()
 | 
			
		||||
		return rows, nil
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -869,9 +839,7 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
 | 
			
		|||
	st.exec(args)
 | 
			
		||||
	return &rows{
 | 
			
		||||
		cn:         cn,
 | 
			
		||||
		colNames: st.colNames,
 | 
			
		||||
		colTyps:  st.colTyps,
 | 
			
		||||
		colFmts:  st.colFmts,
 | 
			
		||||
		rowsHeader: st.rowsHeader,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -992,7 +960,6 @@ func (cn *conn) recv() (t byte, r *readBuf) {
 | 
			
		|||
		if err != nil {
 | 
			
		||||
			panic(err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		switch t {
 | 
			
		||||
		case 'E':
 | 
			
		||||
			panic(parseError(r))
 | 
			
		||||
| 
						 | 
				
			
			@ -1163,6 +1130,55 @@ func (cn *conn) auth(r *readBuf, o values) {
 | 
			
		|||
		if r.int32() != 0 {
 | 
			
		||||
			errorf("unexpected authentication response: %q", t)
 | 
			
		||||
		}
 | 
			
		||||
	case 10:
 | 
			
		||||
		sc := scram.NewClient(sha256.New, o["user"], o["password"])
 | 
			
		||||
		sc.Step(nil)
 | 
			
		||||
		if sc.Err() != nil {
 | 
			
		||||
			errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
 | 
			
		||||
		}
 | 
			
		||||
		scOut := sc.Out()
 | 
			
		||||
 | 
			
		||||
		w := cn.writeBuf('p')
 | 
			
		||||
		w.string("SCRAM-SHA-256")
 | 
			
		||||
		w.int32(len(scOut))
 | 
			
		||||
		w.bytes(scOut)
 | 
			
		||||
		cn.send(w)
 | 
			
		||||
 | 
			
		||||
		t, r := cn.recv()
 | 
			
		||||
		if t != 'R' {
 | 
			
		||||
			errorf("unexpected password response: %q", t)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if r.int32() != 11 {
 | 
			
		||||
			errorf("unexpected authentication response: %q", t)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		nextStep := r.next(len(*r))
 | 
			
		||||
		sc.Step(nextStep)
 | 
			
		||||
		if sc.Err() != nil {
 | 
			
		||||
			errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		scOut = sc.Out()
 | 
			
		||||
		w = cn.writeBuf('p')
 | 
			
		||||
		w.bytes(scOut)
 | 
			
		||||
		cn.send(w)
 | 
			
		||||
 | 
			
		||||
		t, r = cn.recv()
 | 
			
		||||
		if t != 'R' {
 | 
			
		||||
			errorf("unexpected password response: %q", t)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if r.int32() != 12 {
 | 
			
		||||
			errorf("unexpected authentication response: %q", t)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		nextStep = r.next(len(*r))
 | 
			
		||||
		sc.Step(nextStep)
 | 
			
		||||
		if sc.Err() != nil {
 | 
			
		||||
			errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	default:
 | 
			
		||||
		errorf("unknown authentication response: %d", code)
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -1182,10 +1198,8 @@ var colFmtDataAllText = []byte{0, 0}
 | 
			
		|||
type stmt struct {
 | 
			
		||||
	cn   *conn
 | 
			
		||||
	name string
 | 
			
		||||
	colNames   []string
 | 
			
		||||
	colFmts    []format
 | 
			
		||||
	rowsHeader
 | 
			
		||||
	colFmtData []byte
 | 
			
		||||
	colTyps    []fieldDesc
 | 
			
		||||
	paramTyps  []oid.Oid
 | 
			
		||||
	closed     bool
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -1232,9 +1246,7 @@ func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
 | 
			
		|||
	st.exec(v)
 | 
			
		||||
	return &rows{
 | 
			
		||||
		cn:         st.cn,
 | 
			
		||||
		colNames: st.colNames,
 | 
			
		||||
		colTyps:  st.colTyps,
 | 
			
		||||
		colFmts:  st.colFmts,
 | 
			
		||||
		rowsHeader: st.rowsHeader,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -1344,16 +1356,22 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
 | 
			
		|||
	return driver.RowsAffected(n), commandTag
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type rows struct {
 | 
			
		||||
	cn       *conn
 | 
			
		||||
	finish   func()
 | 
			
		||||
type rowsHeader struct {
 | 
			
		||||
	colNames []string
 | 
			
		||||
	colTyps  []fieldDesc
 | 
			
		||||
	colFmts  []format
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type rows struct {
 | 
			
		||||
	cn     *conn
 | 
			
		||||
	finish func()
 | 
			
		||||
	rowsHeader
 | 
			
		||||
	done   bool
 | 
			
		||||
	rb     readBuf
 | 
			
		||||
	result driver.Result
 | 
			
		||||
	tag    string
 | 
			
		||||
 | 
			
		||||
	next *rowsHeader
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (rs *rows) Close() error {
 | 
			
		||||
| 
						 | 
				
			
			@ -1440,7 +1458,8 @@ func (rs *rows) Next(dest []driver.Value) (err error) {
 | 
			
		|||
			}
 | 
			
		||||
			return
 | 
			
		||||
		case 'T':
 | 
			
		||||
			rs.colNames, rs.colFmts, rs.colTyps = parsePortalRowDescribe(&rs.rb)
 | 
			
		||||
			next := parsePortalRowDescribe(&rs.rb)
 | 
			
		||||
			rs.next = &next
 | 
			
		||||
			return io.EOF
 | 
			
		||||
		default:
 | 
			
		||||
			errorf("unexpected message after execute: %q", t)
 | 
			
		||||
| 
						 | 
				
			
			@ -1449,10 +1468,16 @@ func (rs *rows) Next(dest []driver.Value) (err error) {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
func (rs *rows) HasNextResultSet() bool {
 | 
			
		||||
	return !rs.done
 | 
			
		||||
	hasNext := rs.next != nil && !rs.done
 | 
			
		||||
	return hasNext
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (rs *rows) NextResultSet() error {
 | 
			
		||||
	if rs.next == nil {
 | 
			
		||||
		return io.EOF
 | 
			
		||||
	}
 | 
			
		||||
	rs.rowsHeader = *rs.next
 | 
			
		||||
	rs.next = nil
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -1630,13 +1655,13 @@ func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames [
 | 
			
		|||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []fieldDesc) {
 | 
			
		||||
func (cn *conn) readPortalDescribeResponse() rowsHeader {
 | 
			
		||||
	t, r := cn.recv1()
 | 
			
		||||
	switch t {
 | 
			
		||||
	case 'T':
 | 
			
		||||
		return parsePortalRowDescribe(r)
 | 
			
		||||
	case 'n':
 | 
			
		||||
		return nil, nil, nil
 | 
			
		||||
		return rowsHeader{}
 | 
			
		||||
	case 'E':
 | 
			
		||||
		err := parseError(r)
 | 
			
		||||
		cn.readReadyForQuery()
 | 
			
		||||
| 
						 | 
				
			
			@ -1742,11 +1767,11 @@ func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDe
 | 
			
		|||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []fieldDesc) {
 | 
			
		||||
func parsePortalRowDescribe(r *readBuf) rowsHeader {
 | 
			
		||||
	n := r.int16()
 | 
			
		||||
	colNames = make([]string, n)
 | 
			
		||||
	colFmts = make([]format, n)
 | 
			
		||||
	colTyps = make([]fieldDesc, n)
 | 
			
		||||
	colNames := make([]string, n)
 | 
			
		||||
	colFmts := make([]format, n)
 | 
			
		||||
	colTyps := make([]fieldDesc, n)
 | 
			
		||||
	for i := range colNames {
 | 
			
		||||
		colNames[i] = r.string()
 | 
			
		||||
		r.next(6)
 | 
			
		||||
| 
						 | 
				
			
			@ -1755,7 +1780,11 @@ func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, co
 | 
			
		|||
		colTyps[i].Mod = r.int32()
 | 
			
		||||
		colFmts[i] = format(r.int16())
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
	return rowsHeader{
 | 
			
		||||
		colNames: colNames,
 | 
			
		||||
		colFmts:  colFmts,
 | 
			
		||||
		colTyps:  colTyps,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// parseEnviron tries to mimic some of libpq's environment handling
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,5 +1,3 @@
 | 
			
		|||
// +build go1.8
 | 
			
		||||
 | 
			
		||||
package pq
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
| 
						 | 
				
			
			@ -9,6 +7,7 @@ import (
 | 
			
		|||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Implement the "QueryerContext" interface
 | 
			
		||||
| 
						 | 
				
			
			@ -76,13 +75,32 @@ func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx,
 | 
			
		|||
	return tx, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (cn *conn) Ping(ctx context.Context) error {
 | 
			
		||||
	if finish := cn.watchCancel(ctx); finish != nil {
 | 
			
		||||
		defer finish()
 | 
			
		||||
	}
 | 
			
		||||
	rows, err := cn.simpleQuery("SELECT 'lib/pq ping test';")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return driver.ErrBadConn // https://golang.org/pkg/database/sql/driver/#Pinger
 | 
			
		||||
	}
 | 
			
		||||
	rows.Close()
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (cn *conn) watchCancel(ctx context.Context) func() {
 | 
			
		||||
	if done := ctx.Done(); done != nil {
 | 
			
		||||
		finished := make(chan struct{})
 | 
			
		||||
		go func() {
 | 
			
		||||
			select {
 | 
			
		||||
			case <-done:
 | 
			
		||||
				_ = cn.cancel()
 | 
			
		||||
				// At this point the function level context is canceled,
 | 
			
		||||
				// so it must not be used for the additional network
 | 
			
		||||
				// request to cancel the query.
 | 
			
		||||
				// Create a new context to pass into the dial.
 | 
			
		||||
				ctxCancel, cancel := context.WithTimeout(context.Background(), time.Second*10)
 | 
			
		||||
				defer cancel()
 | 
			
		||||
 | 
			
		||||
				_ = cn.cancel(ctxCancel)
 | 
			
		||||
				finished <- struct{}{}
 | 
			
		||||
			case <-finished:
 | 
			
		||||
			}
 | 
			
		||||
| 
						 | 
				
			
			@ -97,8 +115,8 @@ func (cn *conn) watchCancel(ctx context.Context) func() {
 | 
			
		|||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (cn *conn) cancel() error {
 | 
			
		||||
	c, err := dial(cn.dialer, cn.opts)
 | 
			
		||||
func (cn *conn) cancel(ctx context.Context) error {
 | 
			
		||||
	c, err := dial(ctx, cn.dialer, cn.opts)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,10 +1,12 @@
 | 
			
		|||
// +build go1.10
 | 
			
		||||
 | 
			
		||||
package pq
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Connector represents a fixed configuration for the pq driver with a given
 | 
			
		||||
| 
						 | 
				
			
			@ -14,30 +16,95 @@ import (
 | 
			
		|||
//
 | 
			
		||||
// See https://golang.org/pkg/database/sql/driver/#Connector.
 | 
			
		||||
// See https://golang.org/pkg/database/sql/#OpenDB.
 | 
			
		||||
type connector struct {
 | 
			
		||||
	name string
 | 
			
		||||
type Connector struct {
 | 
			
		||||
	opts   values
 | 
			
		||||
	dialer Dialer
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Connect returns a connection to the database using the fixed configuration
 | 
			
		||||
// of this Connector. Context is not used.
 | 
			
		||||
func (c *connector) Connect(_ context.Context) (driver.Conn, error) {
 | 
			
		||||
	return (&Driver{}).Open(c.name)
 | 
			
		||||
func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
 | 
			
		||||
	return c.open(ctx)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Driver returnst the underlying driver of this Connector.
 | 
			
		||||
func (c *connector) Driver() driver.Driver {
 | 
			
		||||
func (c *Connector) Driver() driver.Driver {
 | 
			
		||||
	return &Driver{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ driver.Connector = &connector{}
 | 
			
		||||
 | 
			
		||||
// NewConnector returns a connector for the pq driver in a fixed configuration
 | 
			
		||||
// with the given name. The returned connector can be used to create any number
 | 
			
		||||
// with the given dsn. The returned connector can be used to create any number
 | 
			
		||||
// of equivalent Conn's. The returned connector is intended to be used with
 | 
			
		||||
// database/sql.OpenDB.
 | 
			
		||||
//
 | 
			
		||||
// See https://golang.org/pkg/database/sql/driver/#Connector.
 | 
			
		||||
// See https://golang.org/pkg/database/sql/#OpenDB.
 | 
			
		||||
func NewConnector(name string) (driver.Connector, error) {
 | 
			
		||||
	return &connector{name: name}, nil
 | 
			
		||||
func NewConnector(dsn string) (*Connector, error) {
 | 
			
		||||
	var err error
 | 
			
		||||
	o := make(values)
 | 
			
		||||
 | 
			
		||||
	// A number of defaults are applied here, in this order:
 | 
			
		||||
	//
 | 
			
		||||
	// * Very low precedence defaults applied in every situation
 | 
			
		||||
	// * Environment variables
 | 
			
		||||
	// * Explicitly passed connection information
 | 
			
		||||
	o["host"] = "localhost"
 | 
			
		||||
	o["port"] = "5432"
 | 
			
		||||
	// N.B.: Extra float digits should be set to 3, but that breaks
 | 
			
		||||
	// Postgres 8.4 and older, where the max is 2.
 | 
			
		||||
	o["extra_float_digits"] = "2"
 | 
			
		||||
	for k, v := range parseEnviron(os.Environ()) {
 | 
			
		||||
		o[k] = v
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
 | 
			
		||||
		dsn, err = ParseURL(dsn)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := parseOpts(dsn, o); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Use the "fallback" application name if necessary
 | 
			
		||||
	if fallback, ok := o["fallback_application_name"]; ok {
 | 
			
		||||
		if _, ok := o["application_name"]; !ok {
 | 
			
		||||
			o["application_name"] = fallback
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// We can't work with any client_encoding other than UTF-8 currently.
 | 
			
		||||
	// However, we have historically allowed the user to set it to UTF-8
 | 
			
		||||
	// explicitly, and there's no reason to break such programs, so allow that.
 | 
			
		||||
	// Note that the "options" setting could also set client_encoding, but
 | 
			
		||||
	// parsing its value is not worth it.  Instead, we always explicitly send
 | 
			
		||||
	// client_encoding as a separate run-time parameter, which should override
 | 
			
		||||
	// anything set in options.
 | 
			
		||||
	if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) {
 | 
			
		||||
		return nil, errors.New("client_encoding must be absent or 'UTF8'")
 | 
			
		||||
	}
 | 
			
		||||
	o["client_encoding"] = "UTF8"
 | 
			
		||||
	// DateStyle needs a similar treatment.
 | 
			
		||||
	if datestyle, ok := o["datestyle"]; ok {
 | 
			
		||||
		if datestyle != "ISO, MDY" {
 | 
			
		||||
			return nil, fmt.Errorf("setting datestyle must be absent or %v; got %v", "ISO, MDY", datestyle)
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		o["datestyle"] = "ISO, MDY"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// If a user is not provided by any other means, the last
 | 
			
		||||
	// resort is to use the current operating system provided user
 | 
			
		||||
	// name.
 | 
			
		||||
	if _, ok := o["user"]; !ok {
 | 
			
		||||
		u, err := userCurrent()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		o["user"] = u
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &Connector{opts: o, dialer: defaultDialer{}}, nil
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -239,7 +239,7 @@ for more information).  Note that the channel name will be truncated to 63
 | 
			
		|||
bytes by the PostgreSQL server.
 | 
			
		||||
 | 
			
		||||
You can find a complete, working example of Listener usage at
 | 
			
		||||
http://godoc.org/github.com/lib/pq/example/listen.
 | 
			
		||||
https://godoc.org/github.com/lib/pq/example/listen.
 | 
			
		||||
 | 
			
		||||
*/
 | 
			
		||||
package pq
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,264 @@
 | 
			
		|||
// Copyright (c) 2014 - Gustavo Niemeyer <gustavo@niemeyer.net>
 | 
			
		||||
//
 | 
			
		||||
// All rights reserved.
 | 
			
		||||
//
 | 
			
		||||
// Redistribution and use in source and binary forms, with or without
 | 
			
		||||
// modification, are permitted provided that the following conditions are met:
 | 
			
		||||
//
 | 
			
		||||
// 1. Redistributions of source code must retain the above copyright notice, this
 | 
			
		||||
//    list of conditions and the following disclaimer.
 | 
			
		||||
// 2. Redistributions in binary form must reproduce the above copyright notice,
 | 
			
		||||
//    this list of conditions and the following disclaimer in the documentation
 | 
			
		||||
//    and/or other materials provided with the distribution.
 | 
			
		||||
//
 | 
			
		||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 | 
			
		||||
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 | 
			
		||||
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 | 
			
		||||
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 | 
			
		||||
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 | 
			
		||||
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 | 
			
		||||
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 | 
			
		||||
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 | 
			
		||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 | 
			
		||||
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 | 
			
		||||
 | 
			
		||||
// Pacakage scram implements a SCRAM-{SHA-1,etc} client per RFC5802.
 | 
			
		||||
//
 | 
			
		||||
// http://tools.ietf.org/html/rfc5802
 | 
			
		||||
//
 | 
			
		||||
package scram
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"crypto/hmac"
 | 
			
		||||
	"crypto/rand"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"hash"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Client implements a SCRAM-* client (SCRAM-SHA-1, SCRAM-SHA-256, etc).
 | 
			
		||||
//
 | 
			
		||||
// A Client may be used within a SASL conversation with logic resembling:
 | 
			
		||||
//
 | 
			
		||||
//    var in []byte
 | 
			
		||||
//    var client = scram.NewClient(sha1.New, user, pass)
 | 
			
		||||
//    for client.Step(in) {
 | 
			
		||||
//            out := client.Out()
 | 
			
		||||
//            // send out to server
 | 
			
		||||
//            in := serverOut
 | 
			
		||||
//    }
 | 
			
		||||
//    if client.Err() != nil {
 | 
			
		||||
//            // auth failed
 | 
			
		||||
//    }
 | 
			
		||||
//
 | 
			
		||||
type Client struct {
 | 
			
		||||
	newHash func() hash.Hash
 | 
			
		||||
 | 
			
		||||
	user string
 | 
			
		||||
	pass string
 | 
			
		||||
	step int
 | 
			
		||||
	out  bytes.Buffer
 | 
			
		||||
	err  error
 | 
			
		||||
 | 
			
		||||
	clientNonce []byte
 | 
			
		||||
	serverNonce []byte
 | 
			
		||||
	saltedPass  []byte
 | 
			
		||||
	authMsg     bytes.Buffer
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewClient returns a new SCRAM-* client with the provided hash algorithm.
 | 
			
		||||
//
 | 
			
		||||
// For SCRAM-SHA-256, for example, use:
 | 
			
		||||
//
 | 
			
		||||
//    client := scram.NewClient(sha256.New, user, pass)
 | 
			
		||||
//
 | 
			
		||||
func NewClient(newHash func() hash.Hash, user, pass string) *Client {
 | 
			
		||||
	c := &Client{
 | 
			
		||||
		newHash: newHash,
 | 
			
		||||
		user:    user,
 | 
			
		||||
		pass:    pass,
 | 
			
		||||
	}
 | 
			
		||||
	c.out.Grow(256)
 | 
			
		||||
	c.authMsg.Grow(256)
 | 
			
		||||
	return c
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Out returns the data to be sent to the server in the current step.
 | 
			
		||||
func (c *Client) Out() []byte {
 | 
			
		||||
	if c.out.Len() == 0 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return c.out.Bytes()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Err returns the error that ocurred, or nil if there were no errors.
 | 
			
		||||
func (c *Client) Err() error {
 | 
			
		||||
	return c.err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetNonce sets the client nonce to the provided value.
 | 
			
		||||
// If not set, the nonce is generated automatically out of crypto/rand on the first step.
 | 
			
		||||
func (c *Client) SetNonce(nonce []byte) {
 | 
			
		||||
	c.clientNonce = nonce
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var escaper = strings.NewReplacer("=", "=3D", ",", "=2C")
 | 
			
		||||
 | 
			
		||||
// Step processes the incoming data from the server and makes the
 | 
			
		||||
// next round of data for the server available via Client.Out.
 | 
			
		||||
// Step returns false if there are no errors and more data is
 | 
			
		||||
// still expected.
 | 
			
		||||
func (c *Client) Step(in []byte) bool {
 | 
			
		||||
	c.out.Reset()
 | 
			
		||||
	if c.step > 2 || c.err != nil {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	c.step++
 | 
			
		||||
	switch c.step {
 | 
			
		||||
	case 1:
 | 
			
		||||
		c.err = c.step1(in)
 | 
			
		||||
	case 2:
 | 
			
		||||
		c.err = c.step2(in)
 | 
			
		||||
	case 3:
 | 
			
		||||
		c.err = c.step3(in)
 | 
			
		||||
	}
 | 
			
		||||
	return c.step > 2 || c.err != nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Client) step1(in []byte) error {
 | 
			
		||||
	if len(c.clientNonce) == 0 {
 | 
			
		||||
		const nonceLen = 16
 | 
			
		||||
		buf := make([]byte, nonceLen+b64.EncodedLen(nonceLen))
 | 
			
		||||
		if _, err := rand.Read(buf[:nonceLen]); err != nil {
 | 
			
		||||
			return fmt.Errorf("cannot read random SCRAM-SHA-256 nonce from operating system: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
		c.clientNonce = buf[nonceLen:]
 | 
			
		||||
		b64.Encode(c.clientNonce, buf[:nonceLen])
 | 
			
		||||
	}
 | 
			
		||||
	c.authMsg.WriteString("n=")
 | 
			
		||||
	escaper.WriteString(&c.authMsg, c.user)
 | 
			
		||||
	c.authMsg.WriteString(",r=")
 | 
			
		||||
	c.authMsg.Write(c.clientNonce)
 | 
			
		||||
 | 
			
		||||
	c.out.WriteString("n,,")
 | 
			
		||||
	c.out.Write(c.authMsg.Bytes())
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var b64 = base64.StdEncoding
 | 
			
		||||
 | 
			
		||||
func (c *Client) step2(in []byte) error {
 | 
			
		||||
	c.authMsg.WriteByte(',')
 | 
			
		||||
	c.authMsg.Write(in)
 | 
			
		||||
 | 
			
		||||
	fields := bytes.Split(in, []byte(","))
 | 
			
		||||
	if len(fields) != 3 {
 | 
			
		||||
		return fmt.Errorf("expected 3 fields in first SCRAM-SHA-256 server message, got %d: %q", len(fields), in)
 | 
			
		||||
	}
 | 
			
		||||
	if !bytes.HasPrefix(fields[0], []byte("r=")) || len(fields[0]) < 2 {
 | 
			
		||||
		return fmt.Errorf("server sent an invalid SCRAM-SHA-256 nonce: %q", fields[0])
 | 
			
		||||
	}
 | 
			
		||||
	if !bytes.HasPrefix(fields[1], []byte("s=")) || len(fields[1]) < 6 {
 | 
			
		||||
		return fmt.Errorf("server sent an invalid SCRAM-SHA-256 salt: %q", fields[1])
 | 
			
		||||
	}
 | 
			
		||||
	if !bytes.HasPrefix(fields[2], []byte("i=")) || len(fields[2]) < 6 {
 | 
			
		||||
		return fmt.Errorf("server sent an invalid SCRAM-SHA-256 iteration count: %q", fields[2])
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.serverNonce = fields[0][2:]
 | 
			
		||||
	if !bytes.HasPrefix(c.serverNonce, c.clientNonce) {
 | 
			
		||||
		return fmt.Errorf("server SCRAM-SHA-256 nonce is not prefixed by client nonce: got %q, want %q+\"...\"", c.serverNonce, c.clientNonce)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	salt := make([]byte, b64.DecodedLen(len(fields[1][2:])))
 | 
			
		||||
	n, err := b64.Decode(salt, fields[1][2:])
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("cannot decode SCRAM-SHA-256 salt sent by server: %q", fields[1])
 | 
			
		||||
	}
 | 
			
		||||
	salt = salt[:n]
 | 
			
		||||
	iterCount, err := strconv.Atoi(string(fields[2][2:]))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("server sent an invalid SCRAM-SHA-256 iteration count: %q", fields[2])
 | 
			
		||||
	}
 | 
			
		||||
	c.saltPassword(salt, iterCount)
 | 
			
		||||
 | 
			
		||||
	c.authMsg.WriteString(",c=biws,r=")
 | 
			
		||||
	c.authMsg.Write(c.serverNonce)
 | 
			
		||||
 | 
			
		||||
	c.out.WriteString("c=biws,r=")
 | 
			
		||||
	c.out.Write(c.serverNonce)
 | 
			
		||||
	c.out.WriteString(",p=")
 | 
			
		||||
	c.out.Write(c.clientProof())
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Client) step3(in []byte) error {
 | 
			
		||||
	var isv, ise bool
 | 
			
		||||
	var fields = bytes.Split(in, []byte(","))
 | 
			
		||||
	if len(fields) == 1 {
 | 
			
		||||
		isv = bytes.HasPrefix(fields[0], []byte("v="))
 | 
			
		||||
		ise = bytes.HasPrefix(fields[0], []byte("e="))
 | 
			
		||||
	}
 | 
			
		||||
	if ise {
 | 
			
		||||
		return fmt.Errorf("SCRAM-SHA-256 authentication error: %s", fields[0][2:])
 | 
			
		||||
	} else if !isv {
 | 
			
		||||
		return fmt.Errorf("unsupported SCRAM-SHA-256 final message from server: %q", in)
 | 
			
		||||
	}
 | 
			
		||||
	if !bytes.Equal(c.serverSignature(), fields[0][2:]) {
 | 
			
		||||
		return fmt.Errorf("cannot authenticate SCRAM-SHA-256 server signature: %q", fields[0][2:])
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Client) saltPassword(salt []byte, iterCount int) {
 | 
			
		||||
	mac := hmac.New(c.newHash, []byte(c.pass))
 | 
			
		||||
	mac.Write(salt)
 | 
			
		||||
	mac.Write([]byte{0, 0, 0, 1})
 | 
			
		||||
	ui := mac.Sum(nil)
 | 
			
		||||
	hi := make([]byte, len(ui))
 | 
			
		||||
	copy(hi, ui)
 | 
			
		||||
	for i := 1; i < iterCount; i++ {
 | 
			
		||||
		mac.Reset()
 | 
			
		||||
		mac.Write(ui)
 | 
			
		||||
		mac.Sum(ui[:0])
 | 
			
		||||
		for j, b := range ui {
 | 
			
		||||
			hi[j] ^= b
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	c.saltedPass = hi
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Client) clientProof() []byte {
 | 
			
		||||
	mac := hmac.New(c.newHash, c.saltedPass)
 | 
			
		||||
	mac.Write([]byte("Client Key"))
 | 
			
		||||
	clientKey := mac.Sum(nil)
 | 
			
		||||
	hash := c.newHash()
 | 
			
		||||
	hash.Write(clientKey)
 | 
			
		||||
	storedKey := hash.Sum(nil)
 | 
			
		||||
	mac = hmac.New(c.newHash, storedKey)
 | 
			
		||||
	mac.Write(c.authMsg.Bytes())
 | 
			
		||||
	clientProof := mac.Sum(nil)
 | 
			
		||||
	for i, b := range clientKey {
 | 
			
		||||
		clientProof[i] ^= b
 | 
			
		||||
	}
 | 
			
		||||
	clientProof64 := make([]byte, b64.EncodedLen(len(clientProof)))
 | 
			
		||||
	b64.Encode(clientProof64, clientProof)
 | 
			
		||||
	return clientProof64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Client) serverSignature() []byte {
 | 
			
		||||
	mac := hmac.New(c.newHash, c.saltedPass)
 | 
			
		||||
	mac.Write([]byte("Server Key"))
 | 
			
		||||
	serverKey := mac.Sum(nil)
 | 
			
		||||
 | 
			
		||||
	mac = hmac.New(c.newHash, serverKey)
 | 
			
		||||
	mac.Write(c.authMsg.Bytes())
 | 
			
		||||
	serverSignature := mac.Sum(nil)
 | 
			
		||||
 | 
			
		||||
	encoded := make([]byte, b64.EncodedLen(len(serverSignature)))
 | 
			
		||||
	b64.Encode(encoded, serverSignature)
 | 
			
		||||
	return encoded
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -58,7 +58,13 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
 | 
			
		|||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	sslRenegotiation(&tlsConf)
 | 
			
		||||
 | 
			
		||||
	// Accept renegotiation requests initiated by the backend.
 | 
			
		||||
	//
 | 
			
		||||
	// Renegotiation was deprecated then removed from PostgreSQL 9.5, but
 | 
			
		||||
	// the default configuration of older versions has it enabled. Redshift
 | 
			
		||||
	// also initiates renegotiations and cannot be reconfigured.
 | 
			
		||||
	tlsConf.Renegotiation = tls.RenegotiateFreelyAsClient
 | 
			
		||||
 | 
			
		||||
	return func(conn net.Conn) (net.Conn, error) {
 | 
			
		||||
		client := tls.Client(conn, &tlsConf)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,14 +0,0 @@
 | 
			
		|||
// +build go1.7
 | 
			
		||||
 | 
			
		||||
package pq
 | 
			
		||||
 | 
			
		||||
import "crypto/tls"
 | 
			
		||||
 | 
			
		||||
// Accept renegotiation requests initiated by the backend.
 | 
			
		||||
//
 | 
			
		||||
// Renegotiation was deprecated then removed from PostgreSQL 9.5, but
 | 
			
		||||
// the default configuration of older versions has it enabled. Redshift
 | 
			
		||||
// also initiates renegotiations and cannot be reconfigured.
 | 
			
		||||
func sslRenegotiation(conf *tls.Config) {
 | 
			
		||||
	conf.Renegotiation = tls.RenegotiateFreelyAsClient
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -1,8 +0,0 @@
 | 
			
		|||
// +build !go1.7
 | 
			
		||||
 | 
			
		||||
package pq
 | 
			
		||||
 | 
			
		||||
import "crypto/tls"
 | 
			
		||||
 | 
			
		||||
// Renegotiation is not supported by crypto/tls until Go 1.7.
 | 
			
		||||
func sslRenegotiation(*tls.Config) {}
 | 
			
		||||
| 
						 | 
				
			
			@ -209,9 +209,10 @@ github.com/klauspost/crc32
 | 
			
		|||
# github.com/lafriks/xormstore v1.0.0
 | 
			
		||||
github.com/lafriks/xormstore
 | 
			
		||||
github.com/lafriks/xormstore/util
 | 
			
		||||
# github.com/lib/pq v1.0.0
 | 
			
		||||
# github.com/lib/pq v1.1.0
 | 
			
		||||
github.com/lib/pq
 | 
			
		||||
github.com/lib/pq/oid
 | 
			
		||||
github.com/lib/pq/scram
 | 
			
		||||
# github.com/lunny/dingtalk_webhook v0.0.0-20171025031554-e3534c89ef96
 | 
			
		||||
github.com/lunny/dingtalk_webhook
 | 
			
		||||
# github.com/lunny/levelqueue v0.0.0-20190217115915-02b525a4418e
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue