Skip to content

Commit be9a218

Browse files
committed
dial: add DialContext function
In order to replace timeouts with contexts in `Connect` instance creation (go-tarantool), I need a `DialContext` function. It accepts context, and cancels, if context is canceled by user. Part of tarantool/go-tarantool#136
1 parent b452431 commit be9a218

File tree

2 files changed

+110
-23
lines changed

2 files changed

+110
-23
lines changed

net.go

Lines changed: 73 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package openssl
1616

1717
import (
18+
"context"
1819
"errors"
1920
"net"
2021
"time"
@@ -90,7 +91,33 @@ func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) {
9091
func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
9192
flags DialFlags) (*Conn, error) {
9293
d := net.Dialer{Timeout: timeout}
93-
return dialSession(d, network, addr, ctx, flags, nil)
94+
if err := prepareCtx(ctx); err != nil {
95+
return nil, err
96+
}
97+
conn, err := createConnection(d, context.Background(), network, addr)
98+
if err != nil {
99+
return nil, err
100+
}
101+
return createSession(conn, flags, addr, ctx, nil)
102+
}
103+
104+
// DialContext acts like Dial but takes a context for network dial.
105+
//
106+
// The context includes only network dial. It does not include OpenSSL calls.
107+
//
108+
// See func Dial for a description of the network, addr, ctx and flags
109+
// parameters.
110+
func DialContext(context context.Context, network, addr string,
111+
ctx *Ctx, flags DialFlags) (*Conn, error) {
112+
d := net.Dialer{}
113+
if err := prepareCtx(ctx); err != nil {
114+
return nil, err
115+
}
116+
conn, err := createConnection(d, context, network, addr)
117+
if err != nil {
118+
return nil, err
119+
}
120+
return createSession(conn, flags, addr, ctx, nil)
94121
}
95122

96123
// DialSession will connect to network/address and then wrap the corresponding
@@ -109,58 +136,81 @@ func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
109136
func DialSession(network, addr string, ctx *Ctx, flags DialFlags,
110137
session []byte) (*Conn, error) {
111138
var d net.Dialer
112-
return dialSession(d, network, addr, ctx, flags, session)
113-
}
114-
115-
func dialSession(d net.Dialer, network, addr string, ctx *Ctx, flags DialFlags,
116-
session []byte) (*Conn, error) {
117-
host, _, err := net.SplitHostPort(addr)
139+
if err := prepareCtx(ctx); err != nil {
140+
return nil, err
141+
}
142+
conn, err := createConnection(d, context.Background(), network, addr)
118143
if err != nil {
119144
return nil, err
120145
}
146+
return createSession(conn, flags, addr, ctx, session)
147+
}
148+
149+
func prepareCtx(ctx *Ctx) error {
121150
if ctx == nil {
122151
var err error
123152
ctx, err = NewCtx()
124153
if err != nil {
125-
return nil, err
154+
return err
126155
}
127156
// TODO: use operating system default certificate chain?
128157
}
158+
return nil
159+
}
129160

130-
c, err := d.Dial(network, addr)
161+
func createConnection(d net.Dialer, context context.Context, network,
162+
addr string) (net.Conn, error) {
163+
c, err := d.DialContext(context, network, addr)
131164
if err != nil {
132165
return nil, err
133166
}
134-
conn, err := Client(c, ctx)
135-
if err != nil {
136-
c.Close()
137-
return nil, err
138-
}
139-
if session != nil {
140-
err := conn.setSession(session)
141-
if err != nil {
142-
c.Close()
143-
return nil, err
144-
}
145-
}
167+
return c, nil
168+
}
169+
170+
func applyFlags(conn *Conn, host string, flags DialFlags) error {
171+
var err error
146172
if flags&DisableSNI == 0 {
147173
err = conn.SetTlsExtHostName(host)
148174
if err != nil {
149175
conn.Close()
150-
return nil, err
176+
return err
151177
}
152178
}
153179
err = conn.Handshake()
154180
if err != nil {
155181
conn.Close()
156-
return nil, err
182+
return err
157183
}
158184
if flags&InsecureSkipHostVerification == 0 {
159185
err = conn.VerifyHostname(host)
160186
if err != nil {
161187
conn.Close()
188+
return err
189+
}
190+
}
191+
return nil
192+
}
193+
194+
func createSession(c net.Conn, flags DialFlags, addr string, ctx *Ctx,
195+
session []byte) (*Conn, error) {
196+
host, _, err := net.SplitHostPort(addr)
197+
if err != nil {
198+
return nil, err
199+
}
200+
conn, err := Client(c, ctx)
201+
if err != nil {
202+
c.Close()
203+
return nil, err
204+
}
205+
if session != nil {
206+
err := conn.setSession(session)
207+
if err != nil {
208+
c.Close()
162209
return nil, err
163210
}
164211
}
212+
if err := applyFlags(conn, host, flags); err != nil {
213+
return nil, err
214+
}
165215
return conn, nil
166216
}

net_test.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package openssl
2+
3+
import (
4+
"context"
5+
"net"
6+
"strings"
7+
"testing"
8+
"time"
9+
)
10+
11+
func TestDialTimeout(t *testing.T) {
12+
tcp_listener, err := net.Listen("tcp", "localhost:0")
13+
if err != nil {
14+
t.Fatal(err)
15+
}
16+
ctx, _ := NewCtx()
17+
_, err = DialTimeout(tcp_listener.Addr().Network(),
18+
tcp_listener.Addr().String(), time.Nanosecond, ctx, 0)
19+
if err == nil || !strings.Contains(err.Error(), "i/o timeout") {
20+
t.Fatalf("unexpected error, got %v", err)
21+
}
22+
}
23+
24+
func TestDialContext(t *testing.T) {
25+
tcp_listener, err := net.Listen("tcp", "localhost:0")
26+
if err != nil {
27+
t.Fatal(err)
28+
}
29+
cancelCtx, cancel := context.WithCancel(context.Background())
30+
ctx, _ := NewCtx()
31+
cancel()
32+
_, err = DialContext(cancelCtx, tcp_listener.Addr().Network(),
33+
tcp_listener.Addr().String(), ctx, 0)
34+
if err == nil || !strings.Contains(err.Error(), "operation was canceled") {
35+
t.Fatalf("unexpected error, got %v", err)
36+
}
37+
}

0 commit comments

Comments
 (0)