Skip to content

Commit 1599a1d

Browse files
committed
add option to exclude certain headers which would ordinarily be added by default
1 parent b69394a commit 1599a1d

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

ws4py/client/__init__.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
class WebSocketBaseClient(WebSocket):
1616
def __init__(self, url, protocols=None, extensions=None,
17-
heartbeat_freq=None, ssl_options=None, headers=None):
17+
heartbeat_freq=None, ssl_options=None, headers=None, exclude_headers=None):
1818
"""
1919
A websocket client that implements :rfc:`6455` and provides a simple
2020
interface to communicate with a websocket server.
@@ -78,6 +78,8 @@ def __init__(self, url, protocols=None, extensions=None,
7878
self.resource = None
7979
self.ssl_options = ssl_options or {}
8080
self.extra_headers = headers or []
81+
self.exclude_headers = exclude_headers or []
82+
self.exclude_headers = [x.lower() for x in self.exclude_headers]
8183

8284
if self.scheme == "wss":
8385
# Prevent check_hostname requires server_hostname (ref #187)
@@ -211,7 +213,7 @@ def connect(self):
211213
# default port is now 443; upgrade self.sender to send ssl
212214
self.sock = ssl.wrap_socket(self.sock, **self.ssl_options)
213215
self._is_secure = True
214-
216+
215217
self.sock.connect(self.bind_addr)
216218

217219
self._write(self.handshake_request)
@@ -257,14 +259,15 @@ def handshake_headers(self):
257259
('Sec-WebSocket-Key', self.key.decode('utf-8')),
258260
('Sec-WebSocket-Version', str(max(WS_VERSION)))
259261
]
260-
262+
261263
if self.protocols:
262264
headers.append(('Sec-WebSocket-Protocol', ','.join(self.protocols)))
263265

264266
if self.extra_headers:
265267
headers.extend(self.extra_headers)
266268

267-
if not any(x for x in headers if x[0].lower() == 'origin'):
269+
if not any(x for x in headers if x[0].lower() == 'origin') and \
270+
'origin' not in self.exclude_headers:
268271

269272
scheme, url = self.url.split(":", 1)
270273
parsed = urlsplit(url, scheme="http")
@@ -277,6 +280,8 @@ def handshake_headers(self):
277280
origin = origin + ':' + str(parsed.port)
278281
headers.append(('Origin', origin))
279282

283+
headers = [x for x in headers if x[0].lower() not in self.exclude_headers]
284+
280285
return headers
281286

282287
@property

0 commit comments

Comments
 (0)