14
14
15
15
class WebSocketBaseClient (WebSocket ):
16
16
def __init__ (self , url , protocols = None , extensions = None ,
17
- heartbeat_freq = None , ssl_options = None , headers = None ):
17
+ heartbeat_freq = None , ssl_options = None , headers = None , exclude_headers = None ):
18
18
"""
19
19
A websocket client that implements :rfc:`6455` and provides a simple
20
20
interface to communicate with a websocket server.
@@ -78,6 +78,8 @@ def __init__(self, url, protocols=None, extensions=None,
78
78
self .resource = None
79
79
self .ssl_options = ssl_options or {}
80
80
self .extra_headers = headers or []
81
+ self .exclude_headers = exclude_headers or []
82
+ self .exclude_headers = [x .lower () for x in self .exclude_headers ]
81
83
82
84
if self .scheme == "wss" :
83
85
# Prevent check_hostname requires server_hostname (ref #187)
@@ -211,7 +213,7 @@ def connect(self):
211
213
# default port is now 443; upgrade self.sender to send ssl
212
214
self .sock = ssl .wrap_socket (self .sock , ** self .ssl_options )
213
215
self ._is_secure = True
214
-
216
+
215
217
self .sock .connect (self .bind_addr )
216
218
217
219
self ._write (self .handshake_request )
@@ -257,14 +259,15 @@ def handshake_headers(self):
257
259
('Sec-WebSocket-Key' , self .key .decode ('utf-8' )),
258
260
('Sec-WebSocket-Version' , str (max (WS_VERSION )))
259
261
]
260
-
262
+
261
263
if self .protocols :
262
264
headers .append (('Sec-WebSocket-Protocol' , ',' .join (self .protocols )))
263
265
264
266
if self .extra_headers :
265
267
headers .extend (self .extra_headers )
266
268
267
- if not any (x for x in headers if x [0 ].lower () == 'origin' ):
269
+ if not any (x for x in headers if x [0 ].lower () == 'origin' ) and \
270
+ 'origin' not in self .exclude_headers :
268
271
269
272
scheme , url = self .url .split (":" , 1 )
270
273
parsed = urlsplit (url , scheme = "http" )
@@ -277,6 +280,8 @@ def handshake_headers(self):
277
280
origin = origin + ':' + str (parsed .port )
278
281
headers .append (('Origin' , origin ))
279
282
283
+ headers = [x for x in headers if x [0 ].lower () not in self .exclude_headers ]
284
+
280
285
return headers
281
286
282
287
@property
0 commit comments