15
15
package openssl
16
16
17
17
import (
18
+ "context"
18
19
"errors"
19
20
"net"
20
21
"time"
@@ -90,7 +91,43 @@ func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) {
90
91
func DialTimeout (network , addr string , timeout time.Duration , ctx * Ctx ,
91
92
flags DialFlags ) (* Conn , error ) {
92
93
d := net.Dialer {Timeout : timeout }
93
- return dialSession (d , network , addr , ctx , flags , nil )
94
+ ctx , err := prepareCtx (ctx )
95
+ if err != nil {
96
+ return nil , err
97
+ }
98
+ conn , err := createConnection (context .Background (), d , network , addr )
99
+ if err != nil {
100
+ return nil , err
101
+ }
102
+ client , err := createSession (conn , flags , addr , ctx , nil )
103
+ if err != nil {
104
+ conn .Close ()
105
+ }
106
+ return client , err
107
+ }
108
+
109
+ // DialContext acts like Dial but takes a context for network dial.
110
+ //
111
+ // The context includes only network dial. It does not include OpenSSL calls.
112
+ //
113
+ // See func Dial for a description of the network, addr, ctx and flags
114
+ // parameters.
115
+ func DialContext (context context.Context , network , addr string ,
116
+ ctx * Ctx , flags DialFlags ) (* Conn , error ) {
117
+ d := net.Dialer {}
118
+ ctx , err := prepareCtx (ctx )
119
+ if err != nil {
120
+ return nil , err
121
+ }
122
+ conn , err := createConnection (context , d , network , addr )
123
+ if err != nil {
124
+ return nil , err
125
+ }
126
+ client , err := createSession (conn , flags , addr , ctx , nil )
127
+ if err != nil {
128
+ conn .Close ()
129
+ }
130
+ return client , err
94
131
}
95
132
96
133
// DialSession will connect to network/address and then wrap the corresponding
@@ -109,15 +146,22 @@ func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
109
146
func DialSession (network , addr string , ctx * Ctx , flags DialFlags ,
110
147
session []byte ) (* Conn , error ) {
111
148
var d net.Dialer
112
- return dialSession (d , network , addr , ctx , flags , session )
113
- }
114
-
115
- func dialSession (d net.Dialer , network , addr string , ctx * Ctx , flags DialFlags ,
116
- session []byte ) (* Conn , error ) {
117
- host , _ , err := net .SplitHostPort (addr )
149
+ ctx , err := prepareCtx (ctx )
118
150
if err != nil {
119
151
return nil , err
120
152
}
153
+ conn , err := createConnection (context .Background (), d , network , addr )
154
+ if err != nil {
155
+ return nil , err
156
+ }
157
+ client , err := createSession (conn , flags , addr , ctx , session )
158
+ if err != nil {
159
+ conn .Close ()
160
+ }
161
+ return client , err
162
+ }
163
+
164
+ func prepareCtx (ctx * Ctx ) (* Ctx , error ) {
121
165
if ctx == nil {
122
166
var err error
123
167
ctx , err = NewCtx ()
@@ -126,41 +170,59 @@ func dialSession(d net.Dialer, network, addr string, ctx *Ctx, flags DialFlags,
126
170
}
127
171
// TODO: use operating system default certificate chain?
128
172
}
173
+ return ctx , nil
174
+ }
129
175
130
- c , err := d .Dial (network , addr )
176
+ func createConnection (context context.Context , d net.Dialer , network ,
177
+ addr string ) (net.Conn , error ) {
178
+ c , err := d .DialContext (context , network , addr )
131
179
if err != nil {
132
180
return nil , err
133
181
}
134
- conn , err := Client (c , ctx )
135
- if err != nil {
136
- c .Close ()
137
- return nil , err
138
- }
139
- if session != nil {
140
- err := conn .setSession (session )
141
- if err != nil {
142
- c .Close ()
143
- return nil , err
144
- }
145
- }
182
+ return c , nil
183
+ }
184
+
185
+ func applyFlags (conn * Conn , host string , flags DialFlags ) error {
186
+ var err error
146
187
if flags & DisableSNI == 0 {
147
188
err = conn .SetTlsExtHostName (host )
148
189
if err != nil {
149
- conn .Close ()
150
- return nil , err
190
+ return err
151
191
}
152
192
}
153
193
err = conn .Handshake ()
154
194
if err != nil {
155
- conn .Close ()
156
- return nil , err
195
+ return err
157
196
}
158
197
if flags & InsecureSkipHostVerification == 0 {
159
198
err = conn .VerifyHostname (host )
199
+ if err != nil {
200
+ return err
201
+ }
202
+ }
203
+ return nil
204
+ }
205
+
206
+ func createSession (c net.Conn , flags DialFlags , addr string , ctx * Ctx ,
207
+ session []byte ) (* Conn , error ) {
208
+ host , _ , err := net .SplitHostPort (addr )
209
+ if err != nil {
210
+ return nil , err
211
+ }
212
+ conn , err := Client (c , ctx )
213
+ if err != nil {
214
+ return nil , err
215
+ }
216
+ if session != nil {
217
+ err := conn .setSession (session )
160
218
if err != nil {
161
219
conn .Close ()
162
220
return nil , err
163
221
}
164
222
}
223
+ if err := applyFlags (conn , host , flags ); err != nil {
224
+ conn .Close ()
225
+ return nil , err
226
+ }
165
227
return conn , nil
166
228
}
0 commit comments