Skip to content

api: support iproto feature discovery #226

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release.

### Added

- Support iproto feature discovery (#120).

### Changed

### Fixed
Expand Down
191 changes: 173 additions & 18 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"math"
"net"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -146,6 +147,8 @@ type Connection struct {
lenbuf [PacketLengthBytes]byte

lastStreamId uint64

serverProtocolInfo ProtocolInfo
}

var _ = Connector(&Connection{}) // Check compatibility with connector interface.
Expand Down Expand Up @@ -269,6 +272,10 @@ type Opts struct {
Transport string
// SslOpts is used only if the Transport == 'ssl' is set.
Ssl SslOpts
// RequiredProtocolInfo contains minimal protocol version and
// list of protocol features that should be supported by
// Tarantool server. By default there are no restrictions.
RequiredProtocolInfo ProtocolInfo
}

// SslOpts is a way to configure ssl transport.
Expand All @@ -293,6 +300,16 @@ type SslOpts struct {
Ciphers string
}

// Clone returns a copy of the Opts object.
// Any changes in copy RequiredProtocolInfo will not affect the original
// RequiredProtocolInfo value.
func (opts Opts) Clone() Opts {
optsCopy := opts
optsCopy.RequiredProtocolInfo = opts.RequiredProtocolInfo.Clone()

return optsCopy
}

// Connect creates and configures a new Connection.
//
// Address could be specified in following ways:
Expand All @@ -319,7 +336,7 @@ func Connect(addr string, opts Opts) (conn *Connection, err error) {
contextRequestId: 1,
Greeting: &Greeting{},
control: make(chan struct{}),
opts: opts,
opts: opts.Clone(),
dec: newDecoder(&smallBuf{}),
}
maxprocs := uint32(runtime.GOMAXPROCS(-1))
Expand All @@ -344,9 +361,9 @@ func Connect(addr string, opts Opts) (conn *Connection, err error) {
}
}

if opts.RateLimit > 0 {
conn.rlimit = make(chan struct{}, opts.RateLimit)
if opts.RLimitAction != RLimitDrop && opts.RLimitAction != RLimitWait {
if conn.opts.RateLimit > 0 {
conn.rlimit = make(chan struct{}, conn.opts.RateLimit)
if conn.opts.RLimitAction != RLimitDrop && conn.opts.RLimitAction != RLimitWait {
return nil, errors.New("RLimitAction should be specified to RLimitDone nor RLimitWait")
}
}
Expand Down Expand Up @@ -502,6 +519,18 @@ func (conn *Connection) dial() (err error) {
conn.Greeting.Version = bytes.NewBuffer(greeting[:64]).String()
conn.Greeting.auth = bytes.NewBuffer(greeting[64:108]).String()

// IPROTO_ID requests can be processed without authentication.
// https://www.tarantool.io/en/doc/latest/dev_guide/internals/iproto/requests/#iproto-id
if err = conn.identify(w, r); err != nil {
connection.Close()
return err
}

if err = checkProtocolInfo(opts.RequiredProtocolInfo, conn.serverProtocolInfo); err != nil {
connection.Close()
return fmt.Errorf("identify: %w", err)
}

// Auth
if opts.User != "" {
scr, err := scramble(conn.Greeting.auth, opts.Pass)
Expand Down Expand Up @@ -581,43 +610,83 @@ func pack(h *smallWBuf, enc *encoder, reqid uint32,
return
}

func (conn *Connection) writeAuthRequest(w *bufio.Writer, scramble []byte) (err error) {
func (conn *Connection) writeRequest(w *bufio.Writer, req Request) error {
var packet smallWBuf
req := newAuthRequest(conn.opts.User, string(scramble))
err = pack(&packet, newEncoder(&packet), 0, req, ignoreStreamId, conn.Schema)
err := pack(&packet, newEncoder(&packet), 0, req, ignoreStreamId, conn.Schema)

if err != nil {
return errors.New("auth: pack error " + err.Error())
return fmt.Errorf("pack error: %w", err)
}
if err := write(w, packet.b); err != nil {
return errors.New("auth: write error " + err.Error())
if err = write(w, packet.b); err != nil {
return fmt.Errorf("write error: %w", err)
}
if err = w.Flush(); err != nil {
return errors.New("auth: flush error " + err.Error())
return fmt.Errorf("flush error: %w", err)
}
return
return err
}

func (conn *Connection) writeAuthRequest(w *bufio.Writer, scramble []byte) error {
req := newAuthRequest(conn.opts.User, string(scramble))

err := conn.writeRequest(w, req)
if err != nil {
return fmt.Errorf("auth: %w", err)
}

return nil
}

func (conn *Connection) readAuthResponse(r io.Reader) (err error) {
func (conn *Connection) writeIdRequest(w *bufio.Writer, protocolInfo ProtocolInfo) error {
req := NewIdRequest(protocolInfo)

err := conn.writeRequest(w, req)
if err != nil {
return fmt.Errorf("identify: %w", err)
}

return nil
}

func (conn *Connection) readResponse(r io.Reader) (Response, error) {
respBytes, err := conn.read(r)
if err != nil {
return errors.New("auth: read error " + err.Error())
return Response{}, fmt.Errorf("read error: %w", err)
}

resp := Response{buf: smallBuf{b: respBytes}}
err = resp.decodeHeader(conn.dec)
if err != nil {
return errors.New("auth: decode response header error " + err.Error())
return resp, fmt.Errorf("decode response header error: %w", err)
}
err = resp.decodeBody()
if err != nil {
switch err.(type) {
case Error:
return err
return resp, err
default:
return errors.New("auth: decode response body error " + err.Error())
return resp, fmt.Errorf("decode response body error: %w", err)
}
}
return
return resp, nil
}

func (conn *Connection) readAuthResponse(r io.Reader) error {
_, err := conn.readResponse(r)
if err != nil {
return fmt.Errorf("auth: %w", err)
}

return nil
}

func (conn *Connection) readIdResponse(r io.Reader) (Response, error) {
resp, err := conn.readResponse(r)
if err != nil {
return resp, fmt.Errorf("identify: %w", err)
}

return resp, nil
}

func (conn *Connection) createConnection(reconnect bool) (err error) {
Expand Down Expand Up @@ -1163,3 +1232,89 @@ func (conn *Connection) NewStream() (*Stream, error) {
Conn: conn,
}, nil
}

// checkProtocolInfo checks that expected protocol version is
// and protocol features are supported.
func checkProtocolInfo(expected ProtocolInfo, actual ProtocolInfo) error {
var found bool
var missingFeatures []ProtocolFeature

if expected.Version > actual.Version {
return fmt.Errorf("protocol version %d is not supported", expected.Version)
}

// It seems that iterating over a small list is way faster
// than building a map: https://stackoverflow.com/a/52710077/11646599
for _, expectedFeature := range expected.Features {
found = false
for _, actualFeature := range actual.Features {
if expectedFeature == actualFeature {
found = true
}
}
if !found {
missingFeatures = append(missingFeatures, expectedFeature)
}
}

if len(missingFeatures) == 1 {
return fmt.Errorf("protocol feature %s is not supported", missingFeatures[0])
}

if len(missingFeatures) > 1 {
var sarr []string
for _, missingFeature := range missingFeatures {
sarr = append(sarr, missingFeature.String())
}
return fmt.Errorf("protocol features %s are not supported", strings.Join(sarr, ", "))
}

return nil
}

// identify sends info about client protocol, receives info
// about server protocol in response and stores it in the connection.
func (conn *Connection) identify(w *bufio.Writer, r *bufio.Reader) error {
var ok bool

werr := conn.writeIdRequest(w, clientProtocolInfo)
if werr != nil {
return werr
}

resp, rerr := conn.readIdResponse(r)
if rerr != nil {
if resp.Code == ErrUnknownRequestType {
// IPROTO_ID requests are not supported by server.
return nil
}

return rerr
}

if len(resp.Data) == 0 {
return fmt.Errorf("identify: unexpected response: no data")
}

conn.serverProtocolInfo, ok = resp.Data[0].(ProtocolInfo)
if !ok {
return fmt.Errorf("identify: unexpected response: wrong data")
}

return nil
}

// ServerProtocolVersion returns protocol version and protocol features
// supported by connected Tarantool server. Beware that values might be
// outdated if connection is in a disconnected state.
// Since 1.10.0
func (conn *Connection) ServerProtocolInfo() ProtocolInfo {
return conn.serverProtocolInfo.Clone()
}

// ClientProtocolVersion returns protocol version and protocol features
// supported by Go connection client.
// Since 1.10.0
func (conn *Connection) ClientProtocolInfo() ProtocolInfo {
return clientProtocolInfo.Clone()
}
2 changes: 1 addition & 1 deletion connection_pool/connection_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func ConnectWithOpts(addrs []string, connOpts tarantool.Opts, opts OptsPool) (co

connPool = &ConnectionPool{
addrs: make([]string, 0, len(addrs)),
connOpts: connOpts,
connOpts: connOpts.Clone(),
opts: opts,
state: unknownState,
done: make(chan struct{}),
Expand Down
Loading