15
15
package openssl
16
16
17
17
import (
18
+ "context"
18
19
"errors"
19
20
"net"
20
21
"time"
@@ -90,7 +91,33 @@ func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) {
90
91
func DialTimeout (network , addr string , timeout time.Duration , ctx * Ctx ,
91
92
flags DialFlags ) (* Conn , error ) {
92
93
d := net.Dialer {Timeout : timeout }
93
- return dialSession (d , network , addr , ctx , flags , nil )
94
+ if err := prepareCtx (ctx ); err != nil {
95
+ return nil , err
96
+ }
97
+ conn , err := createConnection (d , context .Background (), network , addr )
98
+ if err != nil {
99
+ return nil , err
100
+ }
101
+ return createSession (conn , flags , addr , ctx , nil )
102
+ }
103
+
104
+ // DialContext acts like Dial but takes a context for network dial.
105
+ //
106
+ // The context includes only network dial. It does not include OpenSSL calls.
107
+ //
108
+ // See func Dial for a description of the network, addr, ctx and flags
109
+ // parameters.
110
+ func DialContext (context context.Context , network , addr string ,
111
+ ctx * Ctx , flags DialFlags ) (* Conn , error ) {
112
+ d := net.Dialer {}
113
+ if err := prepareCtx (ctx ); err != nil {
114
+ return nil , err
115
+ }
116
+ conn , err := createConnection (d , context , network , addr )
117
+ if err != nil {
118
+ return nil , err
119
+ }
120
+ return createSession (conn , flags , addr , ctx , nil )
94
121
}
95
122
96
123
// DialSession will connect to network/address and then wrap the corresponding
@@ -109,58 +136,81 @@ func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
109
136
func DialSession (network , addr string , ctx * Ctx , flags DialFlags ,
110
137
session []byte ) (* Conn , error ) {
111
138
var d net.Dialer
112
- return dialSession (d , network , addr , ctx , flags , session )
113
- }
114
-
115
- func dialSession (d net.Dialer , network , addr string , ctx * Ctx , flags DialFlags ,
116
- session []byte ) (* Conn , error ) {
117
- host , _ , err := net .SplitHostPort (addr )
139
+ if err := prepareCtx (ctx ); err != nil {
140
+ return nil , err
141
+ }
142
+ conn , err := createConnection (d , context .Background (), network , addr )
118
143
if err != nil {
119
144
return nil , err
120
145
}
146
+ return createSession (conn , flags , addr , ctx , session )
147
+ }
148
+
149
+ func prepareCtx (ctx * Ctx ) error {
121
150
if ctx == nil {
122
151
var err error
123
152
ctx , err = NewCtx ()
124
153
if err != nil {
125
- return nil , err
154
+ return err
126
155
}
127
156
// TODO: use operating system default certificate chain?
128
157
}
158
+ return nil
159
+ }
129
160
130
- c , err := d .Dial (network , addr )
161
+ func createConnection (d net.Dialer , context context.Context , network ,
162
+ addr string ) (net.Conn , error ) {
163
+ c , err := d .DialContext (context , network , addr )
131
164
if err != nil {
132
165
return nil , err
133
166
}
134
- conn , err := Client (c , ctx )
135
- if err != nil {
136
- c .Close ()
137
- return nil , err
138
- }
139
- if session != nil {
140
- err := conn .setSession (session )
141
- if err != nil {
142
- c .Close ()
143
- return nil , err
144
- }
145
- }
167
+ return c , nil
168
+ }
169
+
170
+ func applyFlags (conn * Conn , host string , flags DialFlags ) error {
171
+ var err error
146
172
if flags & DisableSNI == 0 {
147
173
err = conn .SetTlsExtHostName (host )
148
174
if err != nil {
149
175
conn .Close ()
150
- return nil , err
176
+ return err
151
177
}
152
178
}
153
179
err = conn .Handshake ()
154
180
if err != nil {
155
181
conn .Close ()
156
- return nil , err
182
+ return err
157
183
}
158
184
if flags & InsecureSkipHostVerification == 0 {
159
185
err = conn .VerifyHostname (host )
160
186
if err != nil {
161
187
conn .Close ()
188
+ return err
189
+ }
190
+ }
191
+ return nil
192
+ }
193
+
194
+ func createSession (c net.Conn , flags DialFlags , addr string , ctx * Ctx ,
195
+ session []byte ) (* Conn , error ) {
196
+ host , _ , err := net .SplitHostPort (addr )
197
+ if err != nil {
198
+ return nil , err
199
+ }
200
+ conn , err := Client (c , ctx )
201
+ if err != nil {
202
+ c .Close ()
203
+ return nil , err
204
+ }
205
+ if session != nil {
206
+ err := conn .setSession (session )
207
+ if err != nil {
208
+ c .Close ()
162
209
return nil , err
163
210
}
164
211
}
212
+ if err := applyFlags (conn , host , flags ); err != nil {
213
+ return nil , err
214
+ }
165
215
return conn , nil
166
216
}
0 commit comments