Skip to content

Commit ae4549a

Browse files
committed
dial: add DialContext function
In order to replace timeouts with contexts in `Connect` instance creation (go-tarantool), I need a `DialContext` function. It accepts context, and cancels, if context is canceled by user. Part of tarantool/go-tarantool#136
1 parent b452431 commit ae4549a

File tree

2 files changed

+163
-24
lines changed

2 files changed

+163
-24
lines changed

net.go

Lines changed: 86 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package openssl
1616

1717
import (
18+
"context"
1819
"errors"
1920
"net"
2021
"time"
@@ -90,7 +91,43 @@ func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) {
9091
func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
9192
flags DialFlags) (*Conn, error) {
9293
d := net.Dialer{Timeout: timeout}
93-
return dialSession(d, network, addr, ctx, flags, nil)
94+
ctx, err := prepareCtx(ctx)
95+
if err != nil {
96+
return nil, err
97+
}
98+
conn, err := createConnection(context.Background(), d, network, addr)
99+
if err != nil {
100+
return nil, err
101+
}
102+
client, err := createSession(conn, flags, addr, ctx, nil)
103+
if err != nil {
104+
conn.Close()
105+
}
106+
return client, err
107+
}
108+
109+
// DialContext acts like Dial but takes a context for network dial.
110+
//
111+
// The context includes only network dial. It does not include OpenSSL calls.
112+
//
113+
// See func Dial for a description of the network, addr, ctx and flags
114+
// parameters.
115+
func DialContext(context context.Context, network, addr string,
116+
ctx *Ctx, flags DialFlags) (*Conn, error) {
117+
d := net.Dialer{}
118+
ctx, err := prepareCtx(ctx)
119+
if err != nil {
120+
return nil, err
121+
}
122+
conn, err := createConnection(context, d, network, addr)
123+
if err != nil {
124+
return nil, err
125+
}
126+
client, err := createSession(conn, flags, addr, ctx, nil)
127+
if err != nil {
128+
conn.Close()
129+
}
130+
return client, err
94131
}
95132

96133
// DialSession will connect to network/address and then wrap the corresponding
@@ -109,15 +146,22 @@ func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
109146
func DialSession(network, addr string, ctx *Ctx, flags DialFlags,
110147
session []byte) (*Conn, error) {
111148
var d net.Dialer
112-
return dialSession(d, network, addr, ctx, flags, session)
113-
}
114-
115-
func dialSession(d net.Dialer, network, addr string, ctx *Ctx, flags DialFlags,
116-
session []byte) (*Conn, error) {
117-
host, _, err := net.SplitHostPort(addr)
149+
ctx, err := prepareCtx(ctx)
118150
if err != nil {
119151
return nil, err
120152
}
153+
conn, err := createConnection(context.Background(), d, network, addr)
154+
if err != nil {
155+
return nil, err
156+
}
157+
client, err := createSession(conn, flags, addr, ctx, session)
158+
if err != nil {
159+
conn.Close()
160+
}
161+
return client, err
162+
}
163+
164+
func prepareCtx(ctx *Ctx) (*Ctx, error) {
121165
if ctx == nil {
122166
var err error
123167
ctx, err = NewCtx()
@@ -126,41 +170,59 @@ func dialSession(d net.Dialer, network, addr string, ctx *Ctx, flags DialFlags,
126170
}
127171
// TODO: use operating system default certificate chain?
128172
}
173+
return ctx, nil
174+
}
129175

130-
c, err := d.Dial(network, addr)
176+
func createConnection(context context.Context, d net.Dialer, network,
177+
addr string) (net.Conn, error) {
178+
c, err := d.DialContext(context, network, addr)
131179
if err != nil {
132180
return nil, err
133181
}
134-
conn, err := Client(c, ctx)
135-
if err != nil {
136-
c.Close()
137-
return nil, err
138-
}
139-
if session != nil {
140-
err := conn.setSession(session)
141-
if err != nil {
142-
c.Close()
143-
return nil, err
144-
}
145-
}
182+
return c, nil
183+
}
184+
185+
func applyFlags(conn *Conn, host string, flags DialFlags) error {
186+
var err error
146187
if flags&DisableSNI == 0 {
147188
err = conn.SetTlsExtHostName(host)
148189
if err != nil {
149-
conn.Close()
150-
return nil, err
190+
return err
151191
}
152192
}
153193
err = conn.Handshake()
154194
if err != nil {
155-
conn.Close()
156-
return nil, err
195+
return err
157196
}
158197
if flags&InsecureSkipHostVerification == 0 {
159198
err = conn.VerifyHostname(host)
199+
if err != nil {
200+
return err
201+
}
202+
}
203+
return nil
204+
}
205+
206+
func createSession(c net.Conn, flags DialFlags, addr string, ctx *Ctx,
207+
session []byte) (*Conn, error) {
208+
host, _, err := net.SplitHostPort(addr)
209+
if err != nil {
210+
return nil, err
211+
}
212+
conn, err := Client(c, ctx)
213+
if err != nil {
214+
return nil, err
215+
}
216+
if session != nil {
217+
err := conn.setSession(session)
160218
if err != nil {
161219
conn.Close()
162220
return nil, err
163221
}
164222
}
223+
if err := applyFlags(conn, host, flags); err != nil {
224+
conn.Close()
225+
return nil, err
226+
}
165227
return conn, nil
166228
}

net_test.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package openssl
2+
3+
import (
4+
"context"
5+
"io"
6+
"io/ioutil"
7+
"net"
8+
"testing"
9+
"time"
10+
)
11+
12+
func sslConnect(t *testing.T, ssl_listener net.Listener) {
13+
for {
14+
conn, err := ssl_listener.Accept()
15+
if err != nil {
16+
t.Errorf("failed accept: %s", err)
17+
continue
18+
}
19+
go func() {
20+
io.Copy(ioutil.Discard, io.LimitReader(conn, 1024))
21+
}()
22+
}
23+
}
24+
25+
func TestDial(t *testing.T) {
26+
ctx := getCtx(t)
27+
if err := ctx.SetCipherList("AES128-SHA"); err != nil {
28+
t.Fatal(err)
29+
}
30+
ssl_listener, err := Listen("tcp", "localhost:0", ctx)
31+
32+
go sslConnect(t, ssl_listener)
33+
34+
client, err := Dial(ssl_listener.Addr().Network(),
35+
ssl_listener.Addr().String(), ctx, InsecureSkipHostVerification)
36+
37+
if err != nil {
38+
t.Fatalf("unexpected err: %v", err)
39+
}
40+
if client.is_shutdown {
41+
t.Fatal("client is closed after creation")
42+
}
43+
}
44+
45+
func TestDialTimeout(t *testing.T) {
46+
ctx := getCtx(t)
47+
if err := ctx.SetCipherList("AES128-SHA"); err != nil {
48+
t.Fatal(err)
49+
}
50+
ssl_listener, err := Listen("tcp", "localhost:0", ctx)
51+
52+
go sslConnect(t, ssl_listener)
53+
54+
client, err := DialTimeout(ssl_listener.Addr().Network(),
55+
ssl_listener.Addr().String(), time.Nanosecond, ctx, 0)
56+
if client != nil || err == nil {
57+
t.Fatalf("expected error")
58+
}
59+
}
60+
61+
func TestDialContext(t *testing.T) {
62+
ctx := getCtx(t)
63+
if err := ctx.SetCipherList("AES128-SHA"); err != nil {
64+
t.Fatal(err)
65+
}
66+
ssl_listener, err := Listen("tcp", "localhost:0", ctx)
67+
68+
go sslConnect(t, ssl_listener)
69+
70+
cancelCtx, cancel := context.WithCancel(context.Background())
71+
cancel()
72+
client, err := DialContext(cancelCtx, ssl_listener.Addr().Network(),
73+
ssl_listener.Addr().String(), ctx, 0)
74+
if client != nil || err == nil {
75+
t.Fatalf("expected error")
76+
}
77+
}

0 commit comments

Comments
 (0)