diff --git a/CHANGELOG.md b/CHANGELOG.md index cec9d35f1..b49105d21 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release. - Support pagination (#246) - A Makefile target to test with race detector (#218) - Support CRUD API (#108) +- An ability to replace a base network connection to a Tarantool + instance (#265) ### Changed diff --git a/connection.go b/connection.go index efdd09e65..7da3f26d0 100644 --- a/connection.go +++ b/connection.go @@ -3,8 +3,6 @@ package tarantool import ( - "bufio" - "bytes" "context" "encoding/binary" "errors" @@ -12,9 +10,7 @@ import ( "io" "log" "math" - "net" "runtime" - "strings" "sync" "sync/atomic" "time" @@ -29,11 +25,6 @@ const ( connClosed = 3 ) -const ( - connTransportNone = "" - connTransportSsl = "ssl" -) - const shutdownEventKey = "box.shutdown" type ConnEventKind int @@ -149,7 +140,7 @@ func (d defaultLogger) Report(event ConnLogKind, conn *Connection, v ...interfac // More on graceful shutdown: https://www.tarantool.io/en/doc/latest/dev_guide/internals/iproto/graceful_shutdown/ type Connection struct { addr string - c net.Conn + c Conn mutex sync.Mutex cond *sync.Cond // Schema contains schema loaded on connection. @@ -237,16 +228,13 @@ type connShard struct { enc *encoder } -// Greeting is a message sent by Tarantool on connect. -type Greeting struct { - Version string - auth string -} - // Opts is a way to configure Connection type Opts struct { // Auth is an authentication method. Auth Auth + // Dialer is a Dialer object used to create a new connection to a + // Tarantool instance. TtDialer is a default one. + Dialer Dialer // Timeout for response to a particular request. The timeout is reset when // push messages are received. If Timeout is zero, any request can be // blocked infinitely. @@ -377,6 +365,9 @@ func Connect(addr string, opts Opts) (conn *Connection, err error) { if conn.opts.Concurrency == 0 || conn.opts.Concurrency > maxprocs*128 { conn.opts.Concurrency = maxprocs * 4 } + if conn.opts.Dialer == nil { + conn.opts.Dialer = TtDialer{} + } if c := conn.opts.Concurrency; c&(c-1) != 0 { for i := uint(1); i < 32; i *= 2 { c |= c >> i @@ -504,118 +495,43 @@ func (conn *Connection) cancelFuture(fut *Future, err error) { } func (conn *Connection) dial() (err error) { - var connection net.Conn - network := "tcp" opts := conn.opts - address := conn.addr - timeout := opts.Reconnect / 2 - transport := opts.Transport - if timeout == 0 { - timeout = 500 * time.Millisecond - } else if timeout > 5*time.Second { - timeout = 5 * time.Second - } - // Unix socket connection - addrLen := len(address) - if addrLen > 0 && (address[0] == '.' || address[0] == '/') { - network = "unix" - } else if addrLen >= 7 && address[0:7] == "unix://" { - network = "unix" - address = address[7:] - } else if addrLen >= 5 && address[0:5] == "unix:" { - network = "unix" - address = address[5:] - } else if addrLen >= 6 && address[0:6] == "unix/:" { - network = "unix" - address = address[6:] - } else if addrLen >= 6 && address[0:6] == "tcp://" { - address = address[6:] - } else if addrLen >= 4 && address[0:4] == "tcp:" { - address = address[4:] - } - if transport == connTransportNone { - connection, err = net.DialTimeout(network, address, timeout) - } else if transport == connTransportSsl { - connection, err = sslDialTimeout(network, address, timeout, opts.Ssl) - } else { - err = errors.New("An unsupported transport type: " + transport) - } - if err != nil { - return - } - dc := &DeadlineIO{to: opts.Timeout, c: connection} - r := bufio.NewReaderSize(dc, 128*1024) - w := bufio.NewWriterSize(dc, 128*1024) - greeting := make([]byte, 128) - _, err = io.ReadFull(r, greeting) + dialTimeout := opts.Reconnect / 2 + if dialTimeout == 0 { + dialTimeout = 500 * time.Millisecond + } else if dialTimeout > 5*time.Second { + dialTimeout = 5 * time.Second + } + + var c Conn + c, err = conn.opts.Dialer.Dial(conn.addr, DialOpts{ + DialTimeout: dialTimeout, + IoTimeout: opts.Timeout, + Transport: opts.Transport, + Ssl: opts.Ssl, + RequiredProtocol: opts.RequiredProtocolInfo, + Auth: opts.Auth, + User: opts.User, + Password: opts.Pass, + }) if err != nil { - connection.Close() return } - conn.Greeting.Version = bytes.NewBuffer(greeting[:64]).String() - conn.Greeting.auth = bytes.NewBuffer(greeting[64:108]).String() - - // IPROTO_ID requests can be processed without authentication. - // https://www.tarantool.io/en/doc/latest/dev_guide/internals/iproto/requests/#iproto-id - if err = conn.identify(w, r); err != nil { - connection.Close() - return err - } - - if err = checkProtocolInfo(opts.RequiredProtocolInfo, conn.serverProtocolInfo); err != nil { - connection.Close() - return fmt.Errorf("identify: %w", err) - } - - // Auth. - if opts.User != "" { - auth := opts.Auth - if opts.Auth == AutoAuth { - if conn.serverProtocolInfo.Auth != AutoAuth { - auth = conn.serverProtocolInfo.Auth - } else { - auth = ChapSha1Auth - } - } - - var req Request - if auth == ChapSha1Auth { - salt := conn.Greeting.auth - req, err = newChapSha1AuthRequest(conn.opts.User, salt, opts.Pass) - if err != nil { - return fmt.Errorf("auth: %w", err) - } - } else if auth == PapSha256Auth { - if opts.Transport != connTransportSsl { - return errors.New("auth: forbidden to use " + auth.String() + - " unless SSL is enabled for the connection") - } - req = newPapSha256AuthRequest(conn.opts.User, opts.Pass) - } else { - connection.Close() - return errors.New("auth: " + auth.String()) - } - if err = conn.writeRequest(w, req); err != nil { - connection.Close() - return fmt.Errorf("auth: %w", err) - } - if _, err = conn.readResponse(r); err != nil { - connection.Close() - return fmt.Errorf("auth: %w", err) - } - } + conn.Greeting.Version = c.Greeting().Version + conn.serverProtocolInfo = c.ProtocolInfo() // Watchers. conn.watchMap.Range(func(key, value interface{}) bool { st := value.(chan watchState) state := <-st if state.unready != nil { + st <- state return true } req := newWatchRequest(key.(string)) - if err = conn.writeRequest(w, req); err != nil { + if err = writeRequest(c, req); err != nil { st <- state return false } @@ -626,17 +542,18 @@ func (conn *Connection) dial() (err error) { }) if err != nil { + c.Close() return fmt.Errorf("unable to register watch: %w", err) } // Only if connected and fully initialized. conn.lockShards() - conn.c = connection + conn.c = c atomic.StoreUint32(&conn.state, connConnected) conn.cond.Broadcast() conn.unlockShards() - go conn.writer(w, connection) - go conn.reader(r, connection) + go conn.writer(c, c) + go conn.reader(c, c) // Subscribe shutdown event to process graceful shutdown. if conn.shutdownWatcher == nil && isFeatureInSlice(WatchersFeature, conn.serverProtocolInfo.Features) { @@ -700,45 +617,6 @@ func pack(h *smallWBuf, enc *encoder, reqid uint32, return } -func (conn *Connection) writeRequest(w *bufio.Writer, req Request) error { - var packet smallWBuf - err := pack(&packet, newEncoder(&packet), 0, req, ignoreStreamId, nil) - - if err != nil { - return fmt.Errorf("pack error: %w", err) - } - if err = write(w, packet.b); err != nil { - return fmt.Errorf("write error: %w", err) - } - if err = w.Flush(); err != nil { - return fmt.Errorf("flush error: %w", err) - } - return err -} - -func (conn *Connection) readResponse(r io.Reader) (Response, error) { - respBytes, err := conn.read(r) - if err != nil { - return Response{}, fmt.Errorf("read error: %w", err) - } - - resp := Response{buf: smallBuf{b: respBytes}} - err = resp.decodeHeader(conn.dec) - if err != nil { - return resp, fmt.Errorf("decode response header error: %w", err) - } - err = resp.decodeBody() - if err != nil { - switch err.(type) { - case Error: - return resp, err - default: - return resp, fmt.Errorf("decode response body error: %w", err) - } - } - return resp, nil -} - func (conn *Connection) createConnection(reconnect bool) (err error) { var reconnects uint for conn.c == nil && conn.state == connDisconnected { @@ -805,7 +683,7 @@ func (conn *Connection) closeConnection(neterr error, forever bool) (err error) return } -func (conn *Connection) reconnectImpl(neterr error, c net.Conn) { +func (conn *Connection) reconnectImpl(neterr error, c Conn) { if conn.opts.Reconnect > 0 { if c == conn.c { conn.closeConnection(neterr, false) @@ -818,7 +696,7 @@ func (conn *Connection) reconnectImpl(neterr error, c net.Conn) { } } -func (conn *Connection) reconnect(neterr error, c net.Conn) { +func (conn *Connection) reconnect(neterr error, c Conn) { conn.mutex.Lock() defer conn.mutex.Unlock() conn.reconnectImpl(neterr, c) @@ -865,7 +743,7 @@ func (conn *Connection) notify(kind ConnEventKind) { } } -func (conn *Connection) writer(w *bufio.Writer, c net.Conn) { +func (conn *Connection) writer(w writeFlusher, c Conn) { var shardn uint32 var packet smallWBuf for atomic.LoadUint32(&conn.state) != connClosed { @@ -897,7 +775,7 @@ func (conn *Connection) writer(w *bufio.Writer, c net.Conn) { if packet.Len() == 0 { continue } - if err := write(w, packet.b); err != nil { + if _, err := w.Write(packet.b); err != nil { conn.reconnect(err, c) return } @@ -945,14 +823,14 @@ func readWatchEvent(reader io.Reader) (connWatchEvent, error) { return event, nil } -func (conn *Connection) reader(r *bufio.Reader, c net.Conn) { +func (conn *Connection) reader(r io.Reader, c Conn) { events := make(chan connWatchEvent, 1024) defer close(events) go conn.eventer(events) for atomic.LoadUint32(&conn.state) != connClosed { - respBytes, err := conn.read(r) + respBytes, err := read(r, conn.lenbuf[:]) if err != nil { conn.reconnect(err, c) return @@ -1299,31 +1177,20 @@ func (conn *Connection) timeouts() { } } -func write(w io.Writer, data []byte) (err error) { - l, err := w.Write(data) - if err != nil { - return - } - if l != len(data) { - panic("Wrong length writed") - } - return -} - -func (conn *Connection) read(r io.Reader) (response []byte, err error) { +func read(r io.Reader, lenbuf []byte) (response []byte, err error) { var length int - if _, err = io.ReadFull(r, conn.lenbuf[:]); err != nil { + if _, err = io.ReadFull(r, lenbuf); err != nil { return } - if conn.lenbuf[0] != 0xce { + if lenbuf[0] != 0xce { err = errors.New("Wrong response header") return } - length = (int(conn.lenbuf[1]) << 24) + - (int(conn.lenbuf[2]) << 16) + - (int(conn.lenbuf[3]) << 8) + - int(conn.lenbuf[4]) + length = (int(lenbuf[1]) << 24) + + (int(lenbuf[2]) << 16) + + (int(lenbuf[3]) << 8) + + int(lenbuf[4]) if length == 0 { err = errors.New("Response should not be 0 length") @@ -1629,78 +1496,6 @@ func (conn *Connection) newWatcherImpl(key string, callback WatchCallback) (Watc }, nil } -// checkProtocolInfo checks that expected protocol version is -// and protocol features are supported. -func checkProtocolInfo(expected ProtocolInfo, actual ProtocolInfo) error { - var found bool - var missingFeatures []ProtocolFeature - - if expected.Version > actual.Version { - return fmt.Errorf("protocol version %d is not supported", expected.Version) - } - - // It seems that iterating over a small list is way faster - // than building a map: https://stackoverflow.com/a/52710077/11646599 - for _, expectedFeature := range expected.Features { - found = false - for _, actualFeature := range actual.Features { - if expectedFeature == actualFeature { - found = true - } - } - if !found { - missingFeatures = append(missingFeatures, expectedFeature) - } - } - - if len(missingFeatures) == 1 { - return fmt.Errorf("protocol feature %s is not supported", missingFeatures[0]) - } - - if len(missingFeatures) > 1 { - var sarr []string - for _, missingFeature := range missingFeatures { - sarr = append(sarr, missingFeature.String()) - } - return fmt.Errorf("protocol features %s are not supported", strings.Join(sarr, ", ")) - } - - return nil -} - -// identify sends info about client protocol, receives info -// about server protocol in response and stores it in the connection. -func (conn *Connection) identify(w *bufio.Writer, r *bufio.Reader) error { - var ok bool - - req := NewIdRequest(clientProtocolInfo) - werr := conn.writeRequest(w, req) - if werr != nil { - return fmt.Errorf("identify: %w", werr) - } - - resp, rerr := conn.readResponse(r) - if rerr != nil { - if resp.Code == ErrUnknownRequestType { - // IPROTO_ID requests are not supported by server. - return nil - } - - return fmt.Errorf("identify: %w", rerr) - } - - if len(resp.Data) == 0 { - return fmt.Errorf("identify: unexpected response: no data") - } - - conn.serverProtocolInfo, ok = resp.Data[0].(ProtocolInfo) - if !ok { - return fmt.Errorf("identify: unexpected response: wrong data") - } - - return nil -} - // ServerProtocolVersion returns protocol version and protocol features // supported by connected Tarantool server. Beware that values might be // outdated if connection is in a disconnected state. diff --git a/dial.go b/dial.go new file mode 100644 index 000000000..abed85e1b --- /dev/null +++ b/dial.go @@ -0,0 +1,392 @@ +package tarantool + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "net" + "strings" + "time" +) + +const ( + dialTransportNone = "" + dialTransportSsl = "ssl" +) + +// Greeting is a message sent by Tarantool on connect. +type Greeting struct { + Version string +} + +// writeFlusher is the interface that groups the basic Write and Flush methods. +type writeFlusher interface { + io.Writer + Flush() error +} + +// Conn is a generic stream-oriented network connection to a Tarantool +// instance. +type Conn interface { + // Read reads data from the connection. + Read(b []byte) (int, error) + // Write writes data to the connection. There may be an internal buffer for + // better performance control from a client side. + Write(b []byte) (int, error) + // Flush writes any buffered data. + Flush() error + // Close closes the connection. + // Any blocked Read or Flush operations will be unblocked and return + // errors. + Close() error + // LocalAddr returns the local network address, if known. + LocalAddr() net.Addr + // RemoteAddr returns the remote network address, if known. + RemoteAddr() net.Addr + // Greeting returns server greeting. + Greeting() Greeting + // ProtocolInfo returns server protocol info. + ProtocolInfo() ProtocolInfo +} + +// DialOpts is a way to configure a Dial method to create a new Conn. +type DialOpts struct { + // DialTimeout is a timeout for an initial network dial. + DialTimeout time.Duration + // IoTimeout is a timeout per a network read/write. + IoTimeout time.Duration + // Transport is a connect transport type. + Transport string + // Ssl configures "ssl" transport. + Ssl SslOpts + // RequiredProtocol contains minimal protocol version and + // list of protocol features that should be supported by + // Tarantool server. By default there are no restrictions. + RequiredProtocol ProtocolInfo + // Auth is an authentication method. + Auth Auth + // Username for logging in to Tarantool. + User string + // User password for logging in to Tarantool. + Password string +} + +// Dialer is the interface that wraps a method to connect to a Tarantool +// instance. The main idea is to provide a ready-to-work connection with +// basic preparation, successful authorization and additional checks. +// +// You can provide your own implementation to Connect() call via Opts.Dialer if +// some functionality is not implemented in the connector. See TtDialer.Dial() +// implementation as example. +type Dialer interface { + // Dial connects to a Tarantool instance to the address with specified + // options. + Dial(address string, opts DialOpts) (Conn, error) +} + +type tntConn struct { + net net.Conn + reader io.Reader + writer writeFlusher + greeting Greeting + protocol ProtocolInfo +} + +// TtDialer is a default implementation of the Dialer interface which is +// used by the connector. +type TtDialer struct { +} + +// Dial connects to a Tarantool instance to the address with specified +// options. +func (t TtDialer) Dial(address string, opts DialOpts) (Conn, error) { + var err error + conn := new(tntConn) + + if conn.net, err = dial(address, opts); err != nil { + return nil, fmt.Errorf("failed to dial: %w", err) + } + + dc := &DeadlineIO{to: opts.IoTimeout, c: conn.net} + conn.reader = bufio.NewReaderSize(dc, 128*1024) + conn.writer = bufio.NewWriterSize(dc, 128*1024) + + var version, salt string + if version, salt, err = readGreeting(conn.reader); err != nil { + conn.net.Close() + return nil, fmt.Errorf("failed to read greeting: %w", err) + } + conn.greeting.Version = version + + if conn.protocol, err = identify(conn.writer, conn.reader); err != nil { + conn.net.Close() + return nil, fmt.Errorf("failed to identify: %w", err) + } + + if err = checkProtocolInfo(opts.RequiredProtocol, conn.protocol); err != nil { + conn.net.Close() + return nil, fmt.Errorf("invalid server protocol: %w", err) + } + + if opts.User != "" { + if opts.Auth == AutoAuth { + if conn.protocol.Auth != AutoAuth { + opts.Auth = conn.protocol.Auth + } else { + opts.Auth = ChapSha1Auth + } + } + + err := authenticate(conn, opts, salt) + if err != nil { + conn.net.Close() + return nil, fmt.Errorf("failed to authenticate: %w", err) + } + } + + return conn, nil +} + +// Read makes tntConn satisfy the Conn interface. +func (c *tntConn) Read(p []byte) (int, error) { + return c.reader.Read(p) +} + +// Write makes tntConn satisfy the Conn interface. +func (c *tntConn) Write(p []byte) (int, error) { + if l, err := c.writer.Write(p); err != nil { + return l, err + } else if l != len(p) { + return l, errors.New("wrong length written") + } else { + return l, nil + } +} + +// Flush makes tntConn satisfy the Conn interface. +func (c *tntConn) Flush() error { + return c.writer.Flush() +} + +// Close makes tntConn satisfy the Conn interface. +func (c *tntConn) Close() error { + return c.net.Close() +} + +// RemoteAddr makes tntConn satisfy the Conn interface. +func (c *tntConn) RemoteAddr() net.Addr { + return c.net.RemoteAddr() +} + +// LocalAddr makes tntConn satisfy the Conn interface. +func (c *tntConn) LocalAddr() net.Addr { + return c.net.LocalAddr() +} + +// Greeting makes tntConn satisfy the Conn interface. +func (c *tntConn) Greeting() Greeting { + return c.greeting +} + +// ProtocolInfo makes tntConn satisfy the Conn interface. +func (c *tntConn) ProtocolInfo() ProtocolInfo { + return c.protocol +} + +// dial connects to a Tarantool instance. +func dial(address string, opts DialOpts) (net.Conn, error) { + network, address := parseAddress(address) + switch opts.Transport { + case dialTransportNone: + return net.DialTimeout(network, address, opts.DialTimeout) + case dialTransportSsl: + return sslDialTimeout(network, address, opts.DialTimeout, opts.Ssl) + default: + return nil, fmt.Errorf("unsupported transport type: %s", opts.Transport) + } +} + +// parseAddress split address into network and address parts. +func parseAddress(address string) (string, string) { + network := "tcp" + addrLen := len(address) + + if addrLen > 0 && (address[0] == '.' || address[0] == '/') { + network = "unix" + } else if addrLen >= 7 && address[0:7] == "unix://" { + network = "unix" + address = address[7:] + } else if addrLen >= 5 && address[0:5] == "unix:" { + network = "unix" + address = address[5:] + } else if addrLen >= 6 && address[0:6] == "unix/:" { + network = "unix" + address = address[6:] + } else if addrLen >= 6 && address[0:6] == "tcp://" { + address = address[6:] + } else if addrLen >= 4 && address[0:4] == "tcp:" { + address = address[4:] + } + + return network, address +} + +// readGreeting reads a greeting message. +func readGreeting(reader io.Reader) (string, string, error) { + var version, salt string + + data := make([]byte, 128) + _, err := io.ReadFull(reader, data) + if err == nil { + version = bytes.NewBuffer(data[:64]).String() + salt = bytes.NewBuffer(data[64:108]).String() + } + + return version, salt, err +} + +// identify sends info about client protocol, receives info +// about server protocol in response and stores it in the connection. +func identify(w writeFlusher, r io.Reader) (ProtocolInfo, error) { + var info ProtocolInfo + + req := NewIdRequest(clientProtocolInfo) + if err := writeRequest(w, req); err != nil { + return info, err + } + + resp, err := readResponse(r) + if err != nil { + if resp.Code == ErrUnknownRequestType { + // IPROTO_ID requests are not supported by server. + return info, nil + } + + return info, err + } + + if len(resp.Data) == 0 { + return info, errors.New("unexpected response: no data") + } + + info, ok := resp.Data[0].(ProtocolInfo) + if !ok { + return info, errors.New("unexpected response: wrong data") + } + + return info, nil +} + +// checkProtocolInfo checks that required protocol version is +// and protocol features are supported. +func checkProtocolInfo(required ProtocolInfo, actual ProtocolInfo) error { + if required.Version > actual.Version { + return fmt.Errorf("protocol version %d is not supported", + required.Version) + } + + // It seems that iterating over a small list is way faster + // than building a map: https://stackoverflow.com/a/52710077/11646599 + var missed []string + for _, requiredFeature := range required.Features { + found := false + for _, actualFeature := range actual.Features { + if requiredFeature == actualFeature { + found = true + } + } + if !found { + missed = append(missed, requiredFeature.String()) + } + } + + switch { + case len(missed) == 1: + return fmt.Errorf("protocol feature %s is not supported", missed[0]) + case len(missed) > 1: + joined := strings.Join(missed, ", ") + return fmt.Errorf("protocol features %s are not supported", joined) + default: + return nil + } +} + +// authenticate authenticate for a connection. +func authenticate(c Conn, opts DialOpts, salt string) error { + auth := opts.Auth + user := opts.User + pass := opts.Password + + var req Request + var err error + + switch opts.Auth { + case ChapSha1Auth: + req, err = newChapSha1AuthRequest(user, pass, salt) + if err != nil { + return err + } + case PapSha256Auth: + if opts.Transport != dialTransportSsl { + return errors.New("forbidden to use " + auth.String() + + " unless SSL is enabled for the connection") + } + req = newPapSha256AuthRequest(user, pass) + default: + return errors.New("unsupported method " + opts.Auth.String()) + } + + if err = writeRequest(c, req); err != nil { + return err + } + if _, err = readResponse(c); err != nil { + return err + } + return nil +} + +// writeRequest writes a request to the writer. +func writeRequest(w writeFlusher, req Request) error { + var packet smallWBuf + err := pack(&packet, newEncoder(&packet), 0, req, ignoreStreamId, nil) + + if err != nil { + return fmt.Errorf("pack error: %w", err) + } + if _, err = w.Write(packet.b); err != nil { + return fmt.Errorf("write error: %w", err) + } + if err = w.Flush(); err != nil { + return fmt.Errorf("flush error: %w", err) + } + return err +} + +// readResponse reads a response from the reader. +func readResponse(r io.Reader) (Response, error) { + var lenbuf [PacketLengthBytes]byte + + respBytes, err := read(r, lenbuf[:]) + if err != nil { + return Response{}, fmt.Errorf("read error: %w", err) + } + + resp := Response{buf: smallBuf{b: respBytes}} + err = resp.decodeHeader(newDecoder(&smallBuf{})) + if err != nil { + return resp, fmt.Errorf("decode response header error: %w", err) + } + + err = resp.decodeBody() + if err != nil { + switch err.(type) { + case Error: + return resp, err + default: + return resp, fmt.Errorf("decode response body error: %w", err) + } + } + return resp, nil +} diff --git a/dial_test.go b/dial_test.go new file mode 100644 index 000000000..182e9c866 --- /dev/null +++ b/dial_test.go @@ -0,0 +1,340 @@ +package tarantool_test + +import ( + "bytes" + "errors" + "net" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/tarantool/go-tarantool" +) + +type mockErrorDialer struct { + err error +} + +func (m mockErrorDialer) Dial(address string, + opts tarantool.DialOpts) (tarantool.Conn, error) { + return nil, m.err +} + +func TestDialer_Dial_error(t *testing.T) { + const errMsg = "any msg" + dialer := mockErrorDialer{ + err: errors.New(errMsg), + } + + conn, err := tarantool.Connect("any", tarantool.Opts{ + Dialer: dialer, + }) + assert.Nil(t, conn) + assert.ErrorContains(t, err, errMsg) +} + +type mockPassedDialer struct { + address string + opts tarantool.DialOpts +} + +func (m *mockPassedDialer) Dial(address string, + opts tarantool.DialOpts) (tarantool.Conn, error) { + m.address = address + m.opts = opts + return nil, errors.New("does not matter") +} + +func TestDialer_Dial_passedOpts(t *testing.T) { + const addr = "127.0.0.1:8080" + opts := tarantool.DialOpts{ + DialTimeout: 500 * time.Millisecond, + IoTimeout: 2, + Transport: "any", + Ssl: tarantool.SslOpts{ + KeyFile: "a", + CertFile: "b", + CaFile: "c", + Ciphers: "d", + }, + RequiredProtocol: tarantool.ProtocolInfo{ + Auth: tarantool.ChapSha1Auth, + Version: 33, + Features: []tarantool.ProtocolFeature{ + tarantool.ErrorExtensionFeature, + }, + }, + Auth: tarantool.ChapSha1Auth, + User: "user", + Password: "password", + } + + dialer := &mockPassedDialer{} + conn, err := tarantool.Connect(addr, tarantool.Opts{ + Dialer: dialer, + Timeout: opts.IoTimeout, + Transport: opts.Transport, + Ssl: opts.Ssl, + Auth: opts.Auth, + User: opts.User, + Pass: opts.Password, + RequiredProtocolInfo: opts.RequiredProtocol, + }) + + assert.Nil(t, conn) + assert.NotNil(t, err) + assert.Equal(t, addr, dialer.address) + assert.Equal(t, opts, dialer.opts) +} + +type mockIoConn struct { + // Sends an event on Read()/Write()/Flush(). + read, written chan struct{} + // Read()/Write() buffers. + readbuf, writebuf bytes.Buffer + // Calls readWg/writeWg.Wait() in Read()/Flush(). + readWg, writeWg sync.WaitGroup + // How many times to wait before a wg.Wait() call. + readWgDelay, writeWgDelay int + // Write()/Read()/Flush()/Close() calls count. + writeCnt, readCnt, flushCnt, closeCnt int + // LocalAddr()/RemoteAddr() calls count. + localCnt, remoteCnt int + // Greeting()/ProtocolInfo() calls count. + greetingCnt, infoCnt int + // Values for LocalAddr()/RemoteAddr(). + local, remote net.Addr + // Value for Greeting(). + greeting tarantool.Greeting + // Value for ProtocolInfo(). + info tarantool.ProtocolInfo +} + +func (m *mockIoConn) Read(b []byte) (int, error) { + m.readCnt++ + if m.readWgDelay == 0 { + m.readWg.Wait() + } + m.readWgDelay-- + + ret, err := m.readbuf.Read(b) + + if m.read != nil { + m.read <- struct{}{} + } + + return ret, err +} + +func (m *mockIoConn) Write(b []byte) (int, error) { + m.writeCnt++ + if m.writeWgDelay == 0 { + m.writeWg.Wait() + } + m.writeWgDelay-- + + ret, err := m.writebuf.Write(b) + + if m.written != nil { + m.written <- struct{}{} + } + + return ret, err +} + +func (m *mockIoConn) Flush() error { + m.flushCnt++ + return nil +} + +func (m *mockIoConn) Close() error { + m.closeCnt++ + return nil +} + +func (m *mockIoConn) LocalAddr() net.Addr { + m.localCnt++ + return m.local +} + +func (m *mockIoConn) RemoteAddr() net.Addr { + m.remoteCnt++ + return m.remote +} + +func (m *mockIoConn) Greeting() tarantool.Greeting { + m.greetingCnt++ + return m.greeting +} + +func (m *mockIoConn) ProtocolInfo() tarantool.ProtocolInfo { + m.infoCnt++ + return m.info +} + +type mockIoDialer struct { + init func(conn *mockIoConn) + conn *mockIoConn +} + +func newMockIoConn() *mockIoConn { + conn := new(mockIoConn) + conn.readWg.Add(1) + conn.writeWg.Add(1) + return conn +} + +func (m *mockIoDialer) Dial(address string, + opts tarantool.DialOpts) (tarantool.Conn, error) { + m.conn = newMockIoConn() + if m.init != nil { + m.init(m.conn) + } + return m.conn, nil +} + +func dialIo(t *testing.T, + init func(conn *mockIoConn)) (*tarantool.Connection, mockIoDialer) { + t.Helper() + + dialer := mockIoDialer{ + init: init, + } + conn, err := tarantool.Connect("any", tarantool.Opts{ + Dialer: &dialer, + Timeout: 1000 * time.Second, // Avoid pings. + SkipSchema: true, + }) + require.Nil(t, err) + require.NotNil(t, conn) + + return conn, dialer +} + +func TestConn_Close(t *testing.T) { + conn, dialer := dialIo(t, nil) + conn.Close() + + assert.Equal(t, 1, dialer.conn.closeCnt) + + dialer.conn.readWg.Done() + dialer.conn.writeWg.Done() +} + +type stubAddr struct { + net.Addr + str string +} + +func (a stubAddr) String() string { + return a.str +} + +func TestConn_LocalAddr(t *testing.T) { + const addr = "any" + conn, dialer := dialIo(t, func(conn *mockIoConn) { + conn.local = stubAddr{str: addr} + }) + defer func() { + dialer.conn.readWg.Done() + dialer.conn.writeWg.Done() + conn.Close() + }() + + assert.Equal(t, addr, conn.LocalAddr()) + assert.Equal(t, 1, dialer.conn.localCnt) +} + +func TestConn_RemoteAddr(t *testing.T) { + const addr = "any" + conn, dialer := dialIo(t, func(conn *mockIoConn) { + conn.remote = stubAddr{str: addr} + }) + defer func() { + dialer.conn.readWg.Done() + dialer.conn.writeWg.Done() + conn.Close() + }() + + assert.Equal(t, addr, conn.RemoteAddr()) + assert.Equal(t, 1, dialer.conn.remoteCnt) +} + +func TestConn_Greeting(t *testing.T) { + greeting := tarantool.Greeting{ + Version: "any", + } + conn, dialer := dialIo(t, func(conn *mockIoConn) { + conn.greeting = greeting + }) + defer func() { + dialer.conn.readWg.Done() + dialer.conn.writeWg.Done() + conn.Close() + }() + + assert.Equal(t, &greeting, conn.Greeting) + assert.Equal(t, 1, dialer.conn.greetingCnt) +} + +func TestConn_ProtocolInfo(t *testing.T) { + info := tarantool.ProtocolInfo{ + Auth: tarantool.ChapSha1Auth, + Version: 33, + Features: []tarantool.ProtocolFeature{ + tarantool.ErrorExtensionFeature, + }, + } + conn, dialer := dialIo(t, func(conn *mockIoConn) { + conn.info = info + }) + defer func() { + dialer.conn.readWg.Done() + dialer.conn.writeWg.Done() + conn.Close() + }() + + assert.Equal(t, info, conn.ServerProtocolInfo()) + assert.Equal(t, 1, dialer.conn.infoCnt) +} + +func TestConn_ReadWrite(t *testing.T) { + conn, dialer := dialIo(t, func(conn *mockIoConn) { + conn.read = make(chan struct{}) + conn.written = make(chan struct{}) + conn.writeWgDelay = 1 + conn.readbuf.Write([]byte{ + 0xce, 0x00, 0x00, 0x00, 0x0a, // Length. + 0x82, // Header map. + 0x00, 0x00, + 0x01, 0xce, 0x00, 0x00, 0x00, 0x02, + 0x80, // Body map. + }) + conn.Close() + }) + defer func() { + dialer.conn.writeWg.Done() + }() + + fut := conn.Do(tarantool.NewPingRequest()) + + <-dialer.conn.written + dialer.conn.readWg.Done() + <-dialer.conn.read + <-dialer.conn.read + + assert.Equal(t, []byte{ + 0xce, 0x00, 0x00, 0x00, 0xa, // Length. + 0x82, // Header map. + 0x00, 0x40, + 0x01, 0xce, 0x00, 0x00, 0x00, 0x02, + 0x80, // Empty map. + }, dialer.conn.writebuf.Bytes()) + + resp, err := fut.Get() + assert.Nil(t, err) + assert.NotNil(t, resp) +} diff --git a/request.go b/request.go index 55e36292d..f6d3cc245 100644 --- a/request.go +++ b/request.go @@ -658,37 +658,54 @@ func (req *spaceIndexRequest) setIndex(index interface{}) { req.index = index } +// authRequest implements IPROTO_AUTH request. type authRequest struct { - baseRequest auth Auth user, pass string } -func newChapSha1AuthRequest(user, salt, password string) (*authRequest, error) { +// newChapSha1AuthRequest create a new authRequest with chap-sha1 +// authentication method. +func newChapSha1AuthRequest(user, password, salt string) (authRequest, error) { + req := authRequest{} scr, err := scramble(salt, password) if err != nil { - return nil, fmt.Errorf("scrambling failure: %w", err) + return req, fmt.Errorf("scrambling failure: %w", err) } - req := new(authRequest) - req.requestCode = AuthRequestCode req.auth = ChapSha1Auth req.user = user req.pass = string(scr) return req, nil } -func newPapSha256AuthRequest(user, password string) *authRequest { - req := new(authRequest) - req.requestCode = AuthRequestCode - req.auth = PapSha256Auth - req.user = user - req.pass = password - return req +// newPapSha256AuthRequest create a new authRequest with pap-sha256 +// authentication method. +func newPapSha256AuthRequest(user, password string) authRequest { + return authRequest{ + auth: PapSha256Auth, + user: user, + pass: password, + } +} + +// Code returns a IPROTO code for the request. +func (req authRequest) Code() int32 { + return AuthRequestCode +} + +// Async returns true if the request does not require a response. +func (req authRequest) Async() bool { + return false +} + +// Ctx returns a context of the request. +func (req authRequest) Ctx() context.Context { + return nil } // Body fills an encoder with the auth request body. -func (req *authRequest) Body(res SchemaResolver, enc *encoder) error { +func (req authRequest) Body(res SchemaResolver, enc *encoder) error { return enc.Encode(map[uint32]interface{}{ KeyUserName: req.user, KeyTuple: []interface{}{req.auth.String(), req.pass}, diff --git a/tarantool_test.go b/tarantool_test.go index f53e6b528..125642dcf 100644 --- a/tarantool_test.go +++ b/tarantool_test.go @@ -2,7 +2,9 @@ package tarantool_test import ( "context" + "encoding/binary" "fmt" + "io" "log" "math" "os" @@ -696,6 +698,60 @@ func BenchmarkSQLSerial(b *testing.B) { } } +func TestTtDialer(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + conn, err := TtDialer{}.Dial(server, DialOpts{}) + require.Nil(err) + require.NotNil(conn) + defer conn.Close() + + assert.Contains(conn.LocalAddr().String(), "127.0.0.1") + assert.Equal(server, conn.RemoteAddr().String()) + assert.NotEqual("", conn.Greeting().Version) + + // Write IPROTO_PING. + ping := []byte{ + 0xce, 0x00, 0x00, 0x00, 0xa, // Length. + 0x82, // Header map. + 0x00, 0x40, + 0x01, 0xce, 0x00, 0x00, 0x00, 0x02, + 0x80, // Empty map. + } + ret, err := conn.Write(ping) + require.Equal(len(ping), ret) + require.Nil(err) + require.Nil(conn.Flush()) + + // Read IPROTO_PING response length. + lenbuf := make([]byte, 5) + ret, err = io.ReadFull(conn, lenbuf) + require.Nil(err) + require.Equal(len(lenbuf), ret) + length := int(binary.BigEndian.Uint32(lenbuf[1:])) + require.Greater(length, 0) + + // Read IPROTO_PING response. + buf := make([]byte, length) + ret, err = io.ReadFull(conn, buf) + require.Nil(err) + require.Equal(len(buf), ret) + // Check that it is IPROTO_OK. + assert.Equal([]byte{0x83, 0x00, 0xce, 0x00, 0x00, 0x00, 0x00}, buf[:7]) +} + +func TestTtDialer_worksWithConnection(t *testing.T) { + defaultOpts := opts + defaultOpts.Dialer = TtDialer{} + + conn := test_helpers.ConnectWithValidation(t, server, defaultOpts) + defer conn.Close() + + _, err := conn.Do(NewPingRequest()).Get() + assert.Nil(t, err) +} + func TestOptsAuth_Default(t *testing.T) { defaultOpts := opts defaultOpts.Auth = AutoAuth @@ -722,8 +778,8 @@ func TestOptsAuth_PapSha256AuthForbit(t *testing.T) { conn.Close() } - if err.Error() != "auth: forbidden to use pap-sha256 unless "+ - "SSL is enabled for the connection" { + if err.Error() != "failed to authenticate: forbidden to use pap-sha256"+ + " unless SSL is enabled for the connection" { t.Errorf("An unexpected error: %s", err) } } @@ -3273,7 +3329,7 @@ func TestConnectionProtocolVersionRequirementFail(t *testing.T) { require.Nilf(t, conn, "Connect fail") require.NotNilf(t, err, "Got error on connect") - require.Contains(t, err.Error(), "identify: protocol version 3 is not supported") + require.Contains(t, err.Error(), "invalid server protocol: protocol version 3 is not supported") } func TestConnectionProtocolFeatureRequirementSuccess(t *testing.T) { @@ -3304,7 +3360,7 @@ func TestConnectionProtocolFeatureRequirementFail(t *testing.T) { require.Nilf(t, conn, "Connect fail") require.NotNilf(t, err, "Got error on connect") - require.Contains(t, err.Error(), "identify: protocol feature TransactionsFeature is not supported") + require.Contains(t, err.Error(), "invalid server protocol: protocol feature TransactionsFeature is not supported") } func TestConnectionProtocolFeatureRequirementManyFail(t *testing.T) { @@ -3321,7 +3377,7 @@ func TestConnectionProtocolFeatureRequirementManyFail(t *testing.T) { require.NotNilf(t, err, "Got error on connect") require.Contains(t, err.Error(), - "identify: protocol features TransactionsFeature, Unknown feature (code 15532) are not supported") + "invalid server protocol: protocol features TransactionsFeature, Unknown feature (code 15532) are not supported") } func TestConnectionFeatureOptsImmutable(t *testing.T) {