Skip to content

Commit bdadb81

Browse files
committed
dial: add DialContext function
In order to replace timeouts with contexts in `Connect` instance creation (go-tarantool), I need a `DialContext` function. It accepts context, and cancels, if context is canceled by user. Part of tarantool/go-tarantool#136
1 parent b452431 commit bdadb81

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

net.go

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package openssl
1616

1717
import (
18+
"context"
1819
"errors"
1920
"net"
2021
"time"
@@ -90,7 +91,19 @@ func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) {
9091
func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
9192
flags DialFlags) (*Conn, error) {
9293
d := net.Dialer{Timeout: timeout}
93-
return dialSession(d, network, addr, ctx, flags, nil)
94+
return dialSession(d, context.Background(), network, addr, ctx, flags, nil)
95+
}
96+
97+
// DialContext acts like Dial but takes a context for network dial.
98+
//
99+
// The context includes only network dial. It does not include OpenSSL calls.
100+
//
101+
// See func Dial for a description of the network, addr, ctx and flags
102+
// parameters.
103+
func DialContext(context context.Context, network, addr string,
104+
ctx *Ctx, flags DialFlags) (*Conn, error) {
105+
d := net.Dialer{}
106+
return dialSession(d, context, network, addr, ctx, flags, nil)
94107
}
95108

96109
// DialSession will connect to network/address and then wrap the corresponding
@@ -109,11 +122,11 @@ func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
109122
func DialSession(network, addr string, ctx *Ctx, flags DialFlags,
110123
session []byte) (*Conn, error) {
111124
var d net.Dialer
112-
return dialSession(d, network, addr, ctx, flags, session)
125+
return dialSession(d, context.Background(), network, addr, ctx, flags, session)
113126
}
114127

115-
func dialSession(d net.Dialer, network, addr string, ctx *Ctx, flags DialFlags,
116-
session []byte) (*Conn, error) {
128+
func dialSession(d net.Dialer, context context.Context, network, addr string,
129+
ctx *Ctx, flags DialFlags, session []byte) (*Conn, error) {
117130
host, _, err := net.SplitHostPort(addr)
118131
if err != nil {
119132
return nil, err
@@ -127,7 +140,7 @@ func dialSession(d net.Dialer, network, addr string, ctx *Ctx, flags DialFlags,
127140
// TODO: use operating system default certificate chain?
128141
}
129142

130-
c, err := d.Dial(network, addr)
143+
c, err := d.DialContext(context, network, addr)
131144
if err != nil {
132145
return nil, err
133146
}

0 commit comments

Comments
 (0)