15
15
package openssl
16
16
17
17
import (
18
+ "context"
18
19
"errors"
19
20
"net"
20
21
"time"
@@ -90,7 +91,37 @@ func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) {
90
91
func DialTimeout (network , addr string , timeout time.Duration , ctx * Ctx ,
91
92
flags DialFlags ) (* Conn , error ) {
92
93
d := net.Dialer {Timeout : timeout }
93
- return dialSession (d , network , addr , ctx , flags , nil )
94
+ ctx , err := prepareCtx (ctx )
95
+ if err != nil {
96
+ return nil , err
97
+ }
98
+ conn , err := createConnection (context .Background (), d , network , addr )
99
+ if err != nil {
100
+ return nil , err
101
+ }
102
+ defer conn .Close ()
103
+ return createSession (conn , flags , addr , ctx , nil )
104
+ }
105
+
106
+ // DialContext acts like Dial but takes a context for network dial.
107
+ //
108
+ // The context includes only network dial. It does not include OpenSSL calls.
109
+ //
110
+ // See func Dial for a description of the network, addr, ctx and flags
111
+ // parameters.
112
+ func DialContext (context context.Context , network , addr string ,
113
+ ctx * Ctx , flags DialFlags ) (* Conn , error ) {
114
+ d := net.Dialer {}
115
+ ctx , err := prepareCtx (ctx )
116
+ if err != nil {
117
+ return nil , err
118
+ }
119
+ conn , err := createConnection (context , d , network , addr )
120
+ if err != nil {
121
+ return nil , err
122
+ }
123
+ defer conn .Close ()
124
+ return createSession (conn , flags , addr , ctx , nil )
94
125
}
95
126
96
127
// DialSession will connect to network/address and then wrap the corresponding
@@ -109,15 +140,19 @@ func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
109
140
func DialSession (network , addr string , ctx * Ctx , flags DialFlags ,
110
141
session []byte ) (* Conn , error ) {
111
142
var d net.Dialer
112
- return dialSession (d , network , addr , ctx , flags , session )
113
- }
114
-
115
- func dialSession (d net.Dialer , network , addr string , ctx * Ctx , flags DialFlags ,
116
- session []byte ) (* Conn , error ) {
117
- host , _ , err := net .SplitHostPort (addr )
143
+ ctx , err := prepareCtx (ctx )
118
144
if err != nil {
119
145
return nil , err
120
146
}
147
+ conn , err := createConnection (context .Background (), d , network , addr )
148
+ if err != nil {
149
+ return nil , err
150
+ }
151
+ defer conn .Close ()
152
+ return createSession (conn , flags , addr , ctx , session )
153
+ }
154
+
155
+ func prepareCtx (ctx * Ctx ) (* Ctx , error ) {
121
156
if ctx == nil {
122
157
var err error
123
158
ctx , err = NewCtx ()
@@ -126,41 +161,58 @@ func dialSession(d net.Dialer, network, addr string, ctx *Ctx, flags DialFlags,
126
161
}
127
162
// TODO: use operating system default certificate chain?
128
163
}
164
+ return ctx , nil
165
+ }
129
166
130
- c , err := d .Dial (network , addr )
167
+ func createConnection (context context.Context , d net.Dialer , network ,
168
+ addr string ) (net.Conn , error ) {
169
+ c , err := d .DialContext (context , network , addr )
131
170
if err != nil {
132
171
return nil , err
133
172
}
134
- conn , err := Client (c , ctx )
135
- if err != nil {
136
- c .Close ()
137
- return nil , err
138
- }
139
- if session != nil {
140
- err := conn .setSession (session )
141
- if err != nil {
142
- c .Close ()
143
- return nil , err
144
- }
145
- }
173
+ return c , nil
174
+ }
175
+
176
+ func applyFlags (conn * Conn , host string , flags DialFlags ) error {
177
+ var err error
146
178
if flags & DisableSNI == 0 {
147
179
err = conn .SetTlsExtHostName (host )
148
180
if err != nil {
149
- conn .Close ()
150
- return nil , err
181
+ return err
151
182
}
152
183
}
153
184
err = conn .Handshake ()
154
185
if err != nil {
155
- conn .Close ()
156
- return nil , err
186
+ return err
157
187
}
158
188
if flags & InsecureSkipHostVerification == 0 {
159
189
err = conn .VerifyHostname (host )
160
190
if err != nil {
161
- conn .Close ()
191
+ return err
192
+ }
193
+ }
194
+ return nil
195
+ }
196
+
197
+ func createSession (c net.Conn , flags DialFlags , addr string , ctx * Ctx ,
198
+ session []byte ) (* Conn , error ) {
199
+ host , _ , err := net .SplitHostPort (addr )
200
+ if err != nil {
201
+ return nil , err
202
+ }
203
+ conn , err := Client (c , ctx )
204
+ if err != nil {
205
+ return nil , err
206
+ }
207
+ defer conn .Close ()
208
+ if session != nil {
209
+ err := conn .setSession (session )
210
+ if err != nil {
162
211
return nil , err
163
212
}
164
213
}
214
+ if err := applyFlags (conn , host , flags ); err != nil {
215
+ return nil , err
216
+ }
165
217
return conn , nil
166
218
}
0 commit comments