Commit 4af0c2f

[Task] A generic payload based work abstraction (#1057)
* Refactor into an internal task submodule of work
* As context managers
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Add missing license

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent d616fc8 commit 4af0c2f
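The central change in this commit is that task executors are now used as context managers. A minimal sketch of the resulting API, assembled from the rewritten `examples/task.py` and the new `proxy.core.work.task` submodule shown in the diffs below (the payload bytes are illustrative):

```python
from proxy.common.flag import FlagParser
from proxy.core.work.task import ThreadedTaskExecutor

if __name__ == '__main__':
    # Route incoming payloads to the new TaskHandler work class; the HTTP
    # proxy front-end is disabled because only the work machinery is needed.
    flags = FlagParser.initialize(
        ['--disable-http-proxy'],
        work_klass='proxy.core.work.task.TaskHandler',
    )
    # Entering the context starts a daemon thread running a LocalTaskExecutor;
    # leaving it sets the executor's running event and joins the thread.
    with ThreadedTaskExecutor(flags=flags) as thread:
        thread.executor.work_queue.put(b'hello world')
```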

File tree

8 files changed: +212 additions, -98 deletions


README.md

Lines changed: 2 additions & 0 deletions
@@ -710,6 +710,8 @@ Start `proxy.py` as:
     --plugins proxy.plugin.CacheResponsesPlugin
 ```

+You may also use the `--cache-requests` flag to enable request packet caching for inspection.
+
 Verify using `curl -v -x localhost:8899 http://httpbin.org/get`:

 ```console
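Assuming the flag documented in the added line above, the cached-responses plugin and request caching can be combined on the command line roughly as follows (a sketch; see the README section itself for the authoritative invocation):

    proxy --plugins proxy.plugin.CacheResponsesPlugin --cache-requests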

examples/task.py

Lines changed: 37 additions & 89 deletions
@@ -8,116 +8,64 @@
     :copyright: (c) 2013-present by Abhinav Singh and contributors.
     :license: BSD, see LICENSE for more details.
 """
-import time
+import sys
 import argparse
-import threading
-import multiprocessing
-from typing import Any

-from proxy.core.work import (
-    Work, ThreadlessPool, BaseLocalExecutor, BaseRemoteExecutor,
-)
+from proxy.core.work import ThreadlessPool
 from proxy.common.flag import FlagParser
-from proxy.common.backports import NonBlockingQueue
-
-
-class Task:
-    """This will be our work object."""
-
-    def __init__(self, payload: bytes) -> None:
-        self.payload = payload
-        print(payload)
-
-
-class TaskWork(Work[Task]):
-    """This will be our handler class, created for each received work."""
-
-    @staticmethod
-    def create(*args: Any) -> Task:
-        """Work core doesn't know how to create work objects for us, so
-        we must provide an implementation of create method here."""
-        return Task(*args)
-
-
-class LocalTaskExecutor(BaseLocalExecutor):
-    """We'll define a local executor which is capable of receiving
-    log lines over a non blocking queue."""
-
-    def work(self, *args: Any) -> None:
-        task_id = int(time.time())
-        uid = '%s-%s' % (self.iid, task_id)
-        self.works[task_id] = self.create(uid, *args)
-
-
-class RemoteTaskExecutor(BaseRemoteExecutor):
-
-    def work(self, *args: Any) -> None:
-        task_id = int(time.time())
-        uid = '%s-%s' % (self.iid, task_id)
-        self.works[task_id] = self.create(uid, *args)
-
-
-def start_local(flags: argparse.Namespace) -> None:
-    work_queue = NonBlockingQueue()
-    executor = LocalTaskExecutor(iid=1, work_queue=work_queue, flags=flags)
+from proxy.core.work.task import (
+    RemoteTaskExecutor, ThreadedTaskExecutor, SingleProcessTaskExecutor,
+)

-    t = threading.Thread(target=executor.run)
-    t.daemon = True
-    t.start()

-    try:
+def start_local_thread(flags: argparse.Namespace) -> None:
+    with ThreadedTaskExecutor(flags=flags) as thread:
         i = 0
         while True:
-            work_queue.put(('%d' % i).encode('utf-8'))
+            thread.executor.work_queue.put(('%d' % i).encode('utf-8'))
             i += 1
-    except KeyboardInterrupt:
-        pass
-    finally:
-        executor.running.set()
-        t.join()


-def start_remote(flags: argparse.Namespace) -> None:
-    pipe = multiprocessing.Pipe()
-    work_queue = pipe[0]
-    executor = RemoteTaskExecutor(iid=1, work_queue=pipe[1], flags=flags)
+def start_remote_process(flags: argparse.Namespace) -> None:
+    with SingleProcessTaskExecutor(flags=flags) as process:
+        i = 0
+        while True:
+            process.work_queue.send(('%d' % i).encode('utf-8'))
+            i += 1

-    p = multiprocessing.Process(target=executor.run)
-    p.daemon = True
-    p.start()

-    try:
+def start_remote_pool(flags: argparse.Namespace) -> None:
+    with ThreadlessPool(flags=flags, executor_klass=RemoteTaskExecutor) as pool:
         i = 0
         while True:
+            work_queue = pool.work_queues[i % flags.num_workers]
             work_queue.send(('%d' % i).encode('utf-8'))
             i += 1
-    except KeyboardInterrupt:
-        pass
-    finally:
-        executor.running.set()
-        p.join()


-def start_remote_pool(flags: argparse.Namespace) -> None:
-    with ThreadlessPool(flags=flags, executor_klass=RemoteTaskExecutor) as pool:
-        try:
-            i = 0
-            while True:
-                work_queue = pool.work_queues[i % flags.num_workers]
-                work_queue.send(('%d' % i).encode('utf-8'))
-                i += 1
-        except KeyboardInterrupt:
-            pass
+def main() -> None:
+    try:
+        flags = FlagParser.initialize(
+            sys.argv[2:] + ['--disable-http-proxy'],
+            work_klass='proxy.core.work.task.TaskHandler',
+        )
+        globals()['start_%s' % sys.argv[1]](flags)
+    except KeyboardInterrupt:
+        pass


 # TODO: TaskWork, LocalTaskExecutor, RemoteTaskExecutor
 # should not be needed, abstract those pieces out in the core
 # for stateless tasks.
 if __name__ == '__main__':
-    flags = FlagParser.initialize(
-        ['--disable-http-proxy'],
-        work_klass=TaskWork,
-    )
-    start_remote_pool(flags)
-    # start_remote(flags)
-    # start_local(flags)
+    if len(sys.argv) < 2:
+        print(
+            '\n'.join([
+                'Usage:',
+                '  %s <execution-mode>' % sys.argv[0],
+                '  execution-mode can be one of the following:',
+                '  "remote_pool", "remote_process", "local_thread"',
+            ]),
+        )
+        sys.exit(1)
+    main()
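Per the usage text added above, the rewritten example is driven from the command line: the first argument selects the execution mode and everything after it is forwarded to `FlagParser.initialize`. For example (the `--num-workers` flag in the second line is illustrative of a pass-through flag, not something this diff defines):

    python examples/task.py local_thread
    python examples/task.py remote_pool --num-workers 4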

proxy/core/base/tcp_upstream.py

Lines changed: 8 additions & 9 deletions
@@ -75,19 +75,18 @@ async def read_from_descriptors(self, r: Readables) -> bool:
                 self.upstream.connection.fileno() in r:
             try:
                 raw = self.upstream.recv(self.server_recvbuf_size)
-                if raw is not None:
-                    self.total_size += len(raw)
-                    self.handle_upstream_data(raw)
-                else:
+                if raw is None:  # pragma: no cover
                     # Tear down because upstream proxy closed the connection
                     return True
-            except TimeoutError:
+                self.total_size += len(raw)
+                self.handle_upstream_data(raw)
+            except TimeoutError:  # pragma: no cover
                 logger.info('Upstream recv timeout error')
                 return True
-            except ssl.SSLWantReadError:
+            except ssl.SSLWantReadError:  # pragma: no cover
                 logger.info('Upstream SSLWantReadError, will retry')
                 return False
-            except ConnectionResetError:
+            except ConnectionResetError:  # pragma: no cover
                 logger.debug('Connection reset by upstream')
                 return True
         return False
@@ -98,10 +97,10 @@ async def write_to_descriptors(self, w: Writables) -> bool:
                 self.upstream.has_buffer():
             try:
                 self.upstream.flush()
-            except ssl.SSLWantWriteError:
+            except ssl.SSLWantWriteError:  # pragma: no cover
                 logger.info('Upstream SSLWantWriteError, will retry')
                 return False
-            except BrokenPipeError:
+            except BrokenPipeError:  # pragma: no cover
                 logger.debug('BrokenPipeError when flushing to upstream')
                 return True
         return False

proxy/core/work/task/__init__.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+# -*- coding: utf-8 -*-
+"""
+    proxy.py
+    ~~~~~~~~
+    ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
+    Network monitoring, controls & Application development, testing, debugging.
+
+    :copyright: (c) 2013-present by Abhinav Singh and contributors.
+    :license: BSD, see LICENSE for more details.
+"""
+from .task import Task
+from .local import LocalTaskExecutor, ThreadedTaskExecutor
+from .remote import RemoteTaskExecutor, SingleProcessTaskExecutor
+from .handler import TaskHandler
+
+
+__all__ = [
+    'Task',
+    'TaskHandler',
+    'LocalTaskExecutor',
+    'ThreadedTaskExecutor',
+    'RemoteTaskExecutor',
+    'SingleProcessTaskExecutor',
+]

proxy/core/work/task/handler.py

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+# -*- coding: utf-8 -*-
+"""
+    proxy.py
+    ~~~~~~~~
+    ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
+    Network monitoring, controls & Application development, testing, debugging.
+
+    :copyright: (c) 2013-present by Abhinav Singh and contributors.
+    :license: BSD, see LICENSE for more details.
+"""
+from typing import Any
+
+from .task import Task
+from ..work import Work
+
+
+class TaskHandler(Work[Task]):
+    """Task handler."""
+
+    @staticmethod
+    def create(*args: Any) -> Task:
+        """Work core doesn't know how to create work objects for us.
+        For example, in the task module scenario, it doesn't know how to
+        create Task objects for us."""
+        return Task(*args)
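A quick sketch of what the static `create` hook does when called directly, assuming the `Task` class added in `task.py` below; the payload bytes are arbitrary:

```python
from proxy.core.work.task import Task, TaskHandler

# create() forwards its arguments to Task(), so a single bytes payload
# comes back wrapped in a Task instance.
work = TaskHandler.create(b'raw payload')
assert isinstance(work, Task)
assert work.payload == b'raw payload'
```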

proxy/core/work/task/local.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+# -*- coding: utf-8 -*-
+"""
+    proxy.py
+    ~~~~~~~~
+    ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
+    Network monitoring, controls & Application development, testing, debugging.
+
+    :copyright: (c) 2013-present by Abhinav Singh and contributors.
+    :license: BSD, see LICENSE for more details.
+"""
+import time
+import uuid
+import threading
+from typing import Any
+
+from ..local import BaseLocalExecutor
+from ....common.backports import NonBlockingQueue
+
+
+class LocalTaskExecutor(BaseLocalExecutor):
+    """We'll define a local executor which is capable of receiving
+    log lines over a non blocking queue."""
+
+    def work(self, *args: Any) -> None:
+        task_id = int(time.time())
+        uid = '%s-%s' % (self.iid, task_id)
+        self.works[task_id] = self.create(uid, *args)
+
+
+class ThreadedTaskExecutor(threading.Thread):
+
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__()
+        self.daemon = True
+        self.executor = LocalTaskExecutor(
+            iid=uuid.uuid4().hex,
+            work_queue=NonBlockingQueue(),
+            **kwargs,
+        )
+
+    def __enter__(self) -> 'ThreadedTaskExecutor':
+        self.start()
+        return self
+
+    def __exit__(self, *args: Any) -> None:
+        self.executor.running.set()
+        self.join()
+
+    def run(self) -> None:
+        self.executor.run()

proxy/core/work/task/remote.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+# -*- coding: utf-8 -*-
+"""
+    proxy.py
+    ~~~~~~~~
+    ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
+    Network monitoring, controls & Application development, testing, debugging.
+
+    :copyright: (c) 2013-present by Abhinav Singh and contributors.
+    :license: BSD, see LICENSE for more details.
+"""
+import time
+import uuid
+import multiprocessing
+from typing import Any
+
+from ..remote import BaseRemoteExecutor
+
+
+class RemoteTaskExecutor(BaseRemoteExecutor):
+
+    def work(self, *args: Any) -> None:
+        task_id = int(time.time())
+        uid = '%s-%s' % (self.iid, task_id)
+        self.works[task_id] = self.create(uid, *args)
+
+
+class SingleProcessTaskExecutor(multiprocessing.Process):
+
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__()
+        self.daemon = True
+        self.work_queue, remote = multiprocessing.Pipe()
+        self.executor = RemoteTaskExecutor(
+            iid=uuid.uuid4().hex,
+            work_queue=remote,
+            **kwargs,
+        )
+
+    def __enter__(self) -> 'SingleProcessTaskExecutor':
+        self.start()
+        return self
+
+    def __exit__(self, *args: Any) -> None:
+        self.executor.running.set()
+        self.join()
+
+    def run(self) -> None:
+        self.executor.run()
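`SingleProcessTaskExecutor` mirrors the threaded variant, except the caller keeps the parent end of a `multiprocessing.Pipe` and sends raw bytes across it. A minimal usage sketch following `start_remote_process` from the rewritten `examples/task.py` (the payload is illustrative):

```python
from proxy.common.flag import FlagParser
from proxy.core.work.task import SingleProcessTaskExecutor

if __name__ == '__main__':
    flags = FlagParser.initialize(
        ['--disable-http-proxy'],
        work_klass='proxy.core.work.task.TaskHandler',
    )
    # Entering the context forks the daemon process; work_queue is the
    # parent end of the pipe, and the child wraps each payload in a Task.
    with SingleProcessTaskExecutor(flags=flags) as process:
        process.work_queue.send(b'hello from the parent process')
```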

proxy/core/work/task/task.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+# -*- coding: utf-8 -*-
+"""
+    proxy.py
+    ~~~~~~~~
+    ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
+    Network monitoring, controls & Application development, testing, debugging.
+
+    :copyright: (c) 2013-present by Abhinav Singh and contributors.
+    :license: BSD, see LICENSE for more details.
+"""
+
+
+class Task:
+    """Task object which knows how to process the payload."""
+
+    def __init__(self, payload: bytes) -> None:
+        self.payload = payload
+        print(payload)
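`Task` is deliberately bare: it stores (and currently prints) the payload. A hypothetical extension that actually interprets the payload might look like the following; `JsonTask` and its JSON payload format are illustrative only, not part of this commit:

```python
import json

from proxy.core.work.task import Task


class JsonTask(Task):
    """Hypothetical Task subclass that decodes a JSON payload."""

    def __init__(self, payload: bytes) -> None:
        super().__init__(payload)
        # Interpret the raw bytes as a JSON document.
        self.data = json.loads(payload)
```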
