15
15
package openssl
16
16
17
17
import (
18
+ "context"
18
19
"errors"
19
20
"net"
20
21
"time"
@@ -89,8 +90,55 @@ func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) {
89
90
// parameters.
90
91
func DialTimeout (network , addr string , timeout time.Duration , ctx * Ctx ,
91
92
flags DialFlags ) (* Conn , error ) {
92
- d := net.Dialer {Timeout : timeout }
93
- return dialSession (d , network , addr , ctx , flags , nil )
93
+ host , err := parseHost (addr )
94
+ if err != nil {
95
+ return nil , err
96
+ }
97
+
98
+ conn , err := net .DialTimeout (network , addr , timeout )
99
+ if err != nil {
100
+ return nil , err
101
+ }
102
+ ctx , err = prepareCtx (ctx )
103
+ if err != nil {
104
+ conn .Close ()
105
+ return nil , err
106
+ }
107
+ client , err := createSession (conn , flags , host , ctx , nil )
108
+ if err != nil {
109
+ conn .Close ()
110
+ }
111
+ return client , err
112
+ }
113
+
114
+ // DialContext acts like Dial but takes a context for network dial.
115
+ //
116
+ // The context includes only network dial. It does not include OpenSSL calls.
117
+ //
118
+ // See func Dial for a description of the network, addr, ctx and flags
119
+ // parameters.
120
+ func DialContext (context context.Context , network , addr string ,
121
+ ctx * Ctx , flags DialFlags ) (* Conn , error ) {
122
+ host , err := parseHost (addr )
123
+ if err != nil {
124
+ return nil , err
125
+ }
126
+
127
+ dialer := net.Dialer {}
128
+ conn , err := dialer .DialContext (context , network , addr )
129
+ if err != nil {
130
+ return nil , err
131
+ }
132
+ ctx , err = prepareCtx (ctx )
133
+ if err != nil {
134
+ conn .Close ()
135
+ return nil , err
136
+ }
137
+ client , err := createSession (conn , flags , host , ctx , nil )
138
+ if err != nil {
139
+ conn .Close ()
140
+ }
141
+ return client , err
94
142
}
95
143
96
144
// DialSession will connect to network/address and then wrap the corresponding
@@ -108,59 +156,76 @@ func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
108
156
// can be retrieved from the GetSession method on the Conn.
109
157
func DialSession (network , addr string , ctx * Ctx , flags DialFlags ,
110
158
session []byte ) (* Conn , error ) {
111
- var d net.Dialer
112
- return dialSession (d , network , addr , ctx , flags , session )
113
- }
114
-
115
- func dialSession (d net.Dialer , network , addr string , ctx * Ctx , flags DialFlags ,
116
- session []byte ) (* Conn , error ) {
117
- host , _ , err := net .SplitHostPort (addr )
159
+ host , err := parseHost (addr )
118
160
if err != nil {
119
161
return nil , err
120
162
}
121
- if ctx == nil {
122
- var err error
123
- ctx , err = NewCtx ()
124
- if err != nil {
125
- return nil , err
126
- }
127
- // TODO: use operating system default certificate chain?
128
- }
129
163
130
- c , err := d .Dial (network , addr )
164
+ conn , err := net .Dial (network , addr )
131
165
if err != nil {
132
166
return nil , err
133
167
}
134
- conn , err := Client ( c , ctx )
168
+ ctx , err = prepareCtx ( ctx )
135
169
if err != nil {
136
- c .Close ()
170
+ conn .Close ()
137
171
return nil , err
138
172
}
139
- if session != nil {
140
- err := conn .setSession (session )
141
- if err != nil {
142
- c .Close ()
143
- return nil , err
144
- }
173
+ client , err := createSession (conn , flags , host , ctx , session )
174
+ if err != nil {
175
+ conn .Close ()
176
+ }
177
+ return client , err
178
+ }
179
+
180
+ func prepareCtx (ctx * Ctx ) (* Ctx , error ) {
181
+ if ctx == nil {
182
+ return NewCtx ()
145
183
}
184
+ return ctx , nil
185
+ }
186
+
187
+ func parseHost (addr string ) (string , error ) {
188
+ host , _ , err := net .SplitHostPort (addr )
189
+ return host , err
190
+ }
191
+
192
+ func handshake (conn * Conn , host string , flags DialFlags ) error {
193
+ var err error
146
194
if flags & DisableSNI == 0 {
147
195
err = conn .SetTlsExtHostName (host )
148
196
if err != nil {
149
- conn .Close ()
150
- return nil , err
197
+ return err
151
198
}
152
199
}
153
200
err = conn .Handshake ()
154
201
if err != nil {
155
- conn .Close ()
156
- return nil , err
202
+ return err
157
203
}
158
204
if flags & InsecureSkipHostVerification == 0 {
159
205
err = conn .VerifyHostname (host )
206
+ if err != nil {
207
+ return err
208
+ }
209
+ }
210
+ return nil
211
+ }
212
+
213
+ func createSession (c net.Conn , flags DialFlags , host string , ctx * Ctx ,
214
+ session []byte ) (* Conn , error ) {
215
+ conn , err := Client (c , ctx )
216
+ if err != nil {
217
+ return nil , err
218
+ }
219
+ if session != nil {
220
+ err := conn .setSession (session )
160
221
if err != nil {
161
222
conn .Close ()
162
223
return nil , err
163
224
}
164
225
}
226
+ if err := handshake (conn , host , flags ); err != nil {
227
+ conn .Close ()
228
+ return nil , err
229
+ }
165
230
return conn , nil
166
231
}
0 commit comments