Skip to content

Commit db9ed8f

Browse files
committed
dial: add DialContext function
In order to replace timeouts with contexts in `Connect` instance creation (go-tarantool), I need a `DialContext` function. It accepts context, and cancels, if context is canceled by user. Part of tarantool/go-tarantool#136
1 parent b452431 commit db9ed8f

File tree

2 files changed

+113
-25
lines changed

2 files changed

+113
-25
lines changed

net.go

Lines changed: 77 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package openssl
1616

1717
import (
18+
"context"
1819
"errors"
1920
"net"
2021
"time"
@@ -90,7 +91,37 @@ func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) {
9091
func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
9192
flags DialFlags) (*Conn, error) {
9293
d := net.Dialer{Timeout: timeout}
93-
return dialSession(d, network, addr, ctx, flags, nil)
94+
ctx, err := prepareCtx(ctx)
95+
if err != nil {
96+
return nil, err
97+
}
98+
conn, err := createConnection(context.Background(), d, network, addr)
99+
if err != nil {
100+
return nil, err
101+
}
102+
defer conn.Close()
103+
return createSession(conn, flags, addr, ctx, nil)
104+
}
105+
106+
// DialContext acts like Dial but takes a context for network dial.
107+
//
108+
// The context includes only network dial. It does not include OpenSSL calls.
109+
//
110+
// See func Dial for a description of the network, addr, ctx and flags
111+
// parameters.
112+
func DialContext(context context.Context, network, addr string,
113+
ctx *Ctx, flags DialFlags) (*Conn, error) {
114+
d := net.Dialer{}
115+
ctx, err := prepareCtx(ctx)
116+
if err != nil {
117+
return nil, err
118+
}
119+
conn, err := createConnection(context, d, network, addr)
120+
if err != nil {
121+
return nil, err
122+
}
123+
defer conn.Close()
124+
return createSession(conn, flags, addr, ctx, nil)
94125
}
95126

96127
// DialSession will connect to network/address and then wrap the corresponding
@@ -109,15 +140,19 @@ func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
109140
func DialSession(network, addr string, ctx *Ctx, flags DialFlags,
110141
session []byte) (*Conn, error) {
111142
var d net.Dialer
112-
return dialSession(d, network, addr, ctx, flags, session)
113-
}
114-
115-
func dialSession(d net.Dialer, network, addr string, ctx *Ctx, flags DialFlags,
116-
session []byte) (*Conn, error) {
117-
host, _, err := net.SplitHostPort(addr)
143+
ctx, err := prepareCtx(ctx)
118144
if err != nil {
119145
return nil, err
120146
}
147+
conn, err := createConnection(context.Background(), d, network, addr)
148+
if err != nil {
149+
return nil, err
150+
}
151+
defer conn.Close()
152+
return createSession(conn, flags, addr, ctx, session)
153+
}
154+
155+
func prepareCtx(ctx *Ctx) (*Ctx, error) {
121156
if ctx == nil {
122157
var err error
123158
ctx, err = NewCtx()
@@ -126,41 +161,58 @@ func dialSession(d net.Dialer, network, addr string, ctx *Ctx, flags DialFlags,
126161
}
127162
// TODO: use operating system default certificate chain?
128163
}
164+
return ctx, nil
165+
}
129166

130-
c, err := d.Dial(network, addr)
167+
func createConnection(context context.Context, d net.Dialer, network,
168+
addr string) (net.Conn, error) {
169+
c, err := d.DialContext(context, network, addr)
131170
if err != nil {
132171
return nil, err
133172
}
134-
conn, err := Client(c, ctx)
135-
if err != nil {
136-
c.Close()
137-
return nil, err
138-
}
139-
if session != nil {
140-
err := conn.setSession(session)
141-
if err != nil {
142-
c.Close()
143-
return nil, err
144-
}
145-
}
173+
return c, nil
174+
}
175+
176+
func applyFlags(conn *Conn, host string, flags DialFlags) error {
177+
var err error
146178
if flags&DisableSNI == 0 {
147179
err = conn.SetTlsExtHostName(host)
148180
if err != nil {
149-
conn.Close()
150-
return nil, err
181+
return err
151182
}
152183
}
153184
err = conn.Handshake()
154185
if err != nil {
155-
conn.Close()
156-
return nil, err
186+
return err
157187
}
158188
if flags&InsecureSkipHostVerification == 0 {
159189
err = conn.VerifyHostname(host)
160190
if err != nil {
161-
conn.Close()
191+
return err
192+
}
193+
}
194+
return nil
195+
}
196+
197+
func createSession(c net.Conn, flags DialFlags, addr string, ctx *Ctx,
198+
session []byte) (*Conn, error) {
199+
host, _, err := net.SplitHostPort(addr)
200+
if err != nil {
201+
return nil, err
202+
}
203+
conn, err := Client(c, ctx)
204+
if err != nil {
205+
return nil, err
206+
}
207+
defer conn.Close()
208+
if session != nil {
209+
err := conn.setSession(session)
210+
if err != nil {
162211
return nil, err
163212
}
164213
}
214+
if err := applyFlags(conn, host, flags); err != nil {
215+
return nil, err
216+
}
165217
return conn, nil
166218
}

net_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package openssl
2+
3+
import (
4+
"context"
5+
"net"
6+
"testing"
7+
"time"
8+
)
9+
10+
func TestDialTimeout(t *testing.T) {
11+
tcp_listener, err := net.Listen("tcp", "localhost:0")
12+
if err != nil {
13+
t.Fatal(err)
14+
}
15+
ctx, _ := NewCtx()
16+
client, err := DialTimeout(tcp_listener.Addr().Network(),
17+
tcp_listener.Addr().String(), time.Nanosecond, ctx, 0)
18+
if client != nil || err == nil {
19+
t.Fatalf("expected error")
20+
}
21+
}
22+
23+
func TestDialContext(t *testing.T) {
24+
tcp_listener, err := net.Listen("tcp", "localhost:0")
25+
if err != nil {
26+
t.Fatal(err)
27+
}
28+
cancelCtx, cancel := context.WithCancel(context.Background())
29+
ctx, _ := NewCtx()
30+
cancel()
31+
client, err := DialContext(cancelCtx, tcp_listener.Addr().Network(),
32+
tcp_listener.Addr().String(), ctx, 0)
33+
if client != nil || err == nil {
34+
t.Fatalf("expected error")
35+
}
36+
}

0 commit comments

Comments
 (0)