Skip to content

Commit d2ae428

Browse files
committed
Retry on Arrow unavailable and timeout exceptions
1 parent 40aed37 commit d2ae428

File tree

3 files changed

+147
-12
lines changed

3 files changed

+147
-12
lines changed

changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
* Reduce calls to the Aura API during GdsSessions::get_or_create.
1616
* Improve error message when a query is interrupted by a signal (SIGINT or SIGTERM).
1717
* Improve error message if session is expired.
18+
* Improve robustness of Arrow client against connection errors such as `FlightUnavailableError` and `FlightTimedOutError`.
1819

1920

2021
## Other changes

graphdatascience/query_runner/gds_arrow_client.py

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,27 @@
1212
import pandas
1313
import pyarrow
1414
from neo4j.exceptions import ClientError
15-
from pyarrow import Array, ChunkedArray, DictionaryArray, RecordBatch, Table, chunked_array, flight
15+
from pyarrow import Array, ChunkedArray, DictionaryArray, RecordBatch, Schema, Table, chunked_array, flight
1616
from pyarrow import __version__ as arrow_version
17-
from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory
17+
from pyarrow.flight import (
18+
ClientMiddleware,
19+
ClientMiddlewareFactory,
20+
FlightDescriptor,
21+
FlightMetadataReader,
22+
FlightStreamWriter,
23+
FlightTimedOutError,
24+
FlightUnavailableError,
25+
)
1826
from pyarrow.types import is_dictionary
19-
from tenacity import retry, retry_if_exception_type, stop_after_attempt, stop_after_delay, wait_exponential
27+
from tenacity import (
28+
retry,
29+
retry_any,
30+
retry_if_exception_type,
31+
stop_after_attempt,
32+
stop_after_delay,
33+
wait_exponential,
34+
wait_fixed,
35+
)
2036

2137
from ..semantic_version.semantic_version import SemanticVersion
2238
from ..version import __version__
@@ -131,6 +147,11 @@ def connection_info(self) -> tuple[str, int]:
131147
"""
132148
return self._host, self._port
133149

150+
@retry(
151+
retry=retry_any(retry_if_exception_type(FlightTimedOutError), retry_if_exception_type(FlightUnavailableError)),
152+
stop=stop_after_attempt(3),
153+
wait=wait_fixed(1),
154+
)
134155
def request_token(self) -> Optional[str]:
135156
"""
136157
Requests a token from the server and returns it.
@@ -571,6 +592,11 @@ def _client(self) -> flight.FlightClient:
571592
self._flight_client = self._instantiate_flight_client()
572593
return self._flight_client
573594

595+
@retry(
596+
retry=retry_any(retry_if_exception_type(FlightTimedOutError), retry_if_exception_type(FlightUnavailableError)),
597+
stop=(stop_after_delay(10) | stop_after_attempt(5)),
598+
wait=wait_exponential(multiplier=1, min=1, max=10),
599+
)
574600
def _send_action(self, action_type: str, meta_data: dict[str, Any]) -> dict[str, Any]:
575601
action_type = self._versioned_action_type(action_type)
576602

@@ -587,6 +613,16 @@ def _send_action(self, action_type: str, meta_data: dict[str, Any]) -> dict[str,
587613
self.handle_flight_error(e)
588614
raise e # unreachable
589615

616+
@retry(
617+
retry=retry_any(retry_if_exception_type(FlightTimedOutError), retry_if_exception_type(FlightUnavailableError)),
618+
stop=(stop_after_delay(10) | stop_after_attempt(5)),
619+
wait=wait_exponential(multiplier=1, min=1, max=10),
620+
)
621+
def _safe_do_put(
622+
self, upload_descriptor: FlightDescriptor, schema: Schema
623+
) -> tuple[FlightStreamWriter, FlightMetadataReader]:
624+
return self._client().do_put(upload_descriptor, schema) # type: ignore
625+
590626
def _upload_data(
591627
self,
592628
graph_name: str,
@@ -605,8 +641,7 @@ def _upload_data(
605641
flight_descriptor = self._versioned_flight_descriptor({"name": graph_name, "entity_type": entity_type})
606642
upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8"))
607643

608-
client = self._client()
609-
put_stream, ack_stream = client.do_put(upload_descriptor, batches[0].schema)
644+
put_stream, ack_stream = self._safe_do_put(upload_descriptor, batches[0].schema)
610645

611646
@retry(
612647
stop=(stop_after_delay(10) | stop_after_attempt(5)),
@@ -629,6 +664,11 @@ def upload_batch(p: RecordBatch) -> None:
629664
except Exception as e:
630665
GdsArrowClient.handle_flight_error(e)
631666

667+
@retry(
668+
retry=retry_any(retry_if_exception_type(FlightTimedOutError), retry_if_exception_type(FlightUnavailableError)),
669+
stop=(stop_after_delay(10) | stop_after_attempt(5)),
670+
wait=wait_exponential(multiplier=1, min=1, max=10),
671+
)
632672
def _do_get(
633673
self,
634674
database: str,

graphdatascience/tests/unit/test_gds_arrow_client.py

Lines changed: 101 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,22 @@
44

55
import pyarrow as pa
66
import pytest
7-
from pyarrow import flight
87
from pyarrow._flight import GeneratorStream
9-
from pyarrow.flight import Action, Ticket
8+
from pyarrow.flight import (
9+
Action,
10+
FlightServerBase,
11+
FlightServerError,
12+
FlightTimedOutError,
13+
FlightUnavailableError,
14+
Ticket,
15+
)
1016

1117
from graphdatascience.query_runner.gds_arrow_client import AuthMiddleware, GdsArrowClient
1218

1319
ActionParam = Union[str, tuple[str, Any], Action]
1420

1521

16-
class FlightServer(flight.FlightServerBase): # type: ignore
22+
class FlightServer(FlightServerBase): # type: ignore
1723
def __init__(self, location: str = "grpc://0.0.0.0:0", **kwargs: dict[str, Any]) -> None:
1824
super(FlightServer, self).__init__(location, **kwargs)
1925
self._location: str = location
@@ -49,18 +55,81 @@ def do_action(self, context: Any, action: ActionParam) -> list[bytes]:
4955
return [json.dumps(response).encode("utf-8")]
5056

5157

58+
class FlakyFlightServer(FlightServerBase): # type: ignore
59+
def __init__(self, location: str = "grpc://0.0.0.0:0", **kwargs: dict[str, Any]) -> None:
60+
super(FlakyFlightServer, self).__init__(location, **kwargs)
61+
self._location: str = location
62+
self._actions: list[ActionParam] = []
63+
self._tickets: list[Ticket] = []
64+
self._expected_failures = [
65+
FlightUnavailableError("Flight server is unavailable", "some reason"),
66+
FlightTimedOutError("Time out for some reason", "still timed out"),
67+
]
68+
self._expected_retries = len(self._expected_failures) + 1
69+
70+
def expected_retries(self) -> int:
71+
return self._expected_retries
72+
73+
def do_get(self, context: Any, ticket: Ticket) -> GeneratorStream:
74+
self._tickets.append(ticket)
75+
76+
if len(self._expected_failures) > 0:
77+
raise self._expected_failures.pop()
78+
79+
table = pa.Table.from_pydict({"ids": [42, 1337, 1234]})
80+
return GeneratorStream(schema=table.schema, generator=table.to_batches())
81+
82+
def do_action(self, context: Any, action: ActionParam) -> list[bytes]:
83+
self._actions.append(action)
84+
85+
if len(self._expected_failures) > 0:
86+
raise self._expected_failures.pop()
87+
88+
if isinstance(action, Action):
89+
actionType = action.type
90+
elif isinstance(action, tuple):
91+
actionType = action[0]
92+
elif isinstance(action, str):
93+
actionType = action
94+
95+
response: dict[str, Any] = {}
96+
if "CREATE" in actionType:
97+
response = {"name": "g"}
98+
elif "NODE_LOAD_DONE" in actionType:
99+
response = {"name": "g", "node_count": 42}
100+
elif "RELATIONSHIP_LOAD_DONE" in actionType:
101+
response = {"name": "g", "relationship_count": 42}
102+
elif "TRIPLET_LOAD_DONE" in actionType:
103+
response = {"name": "g", "node_count": 42, "relationship_count": 1337}
104+
else:
105+
response = {}
106+
return [json.dumps(response).encode("utf-8")]
107+
108+
52109
@pytest.fixture()
53110
def flight_server() -> Generator[None, FlightServer, None]:
54111
with FlightServer() as server:
55112
yield server
56113

57114

115+
@pytest.fixture()
116+
def flaky_flight_server() -> Generator[None, FlakyFlightServer, None]:
117+
with FlakyFlightServer() as server:
118+
yield server
119+
120+
58121
@pytest.fixture()
59122
def flight_client(flight_server: FlightServer) -> Generator[GdsArrowClient, None, None]:
60123
with GdsArrowClient("localhost", flight_server.port) as client:
61124
yield client
62125

63126

127+
@pytest.fixture()
128+
def flaky_flight_client(flaky_flight_server: FlakyFlightServer) -> Generator[GdsArrowClient, None, None]:
129+
with GdsArrowClient("localhost", flaky_flight_server.port) as client:
130+
yield client
131+
132+
64133
def test_create_graph_with_defaults(flight_server: FlightServer, flight_client: GdsArrowClient) -> None:
65134
flight_client.create_graph("g", "DB")
66135
actions = flight_server._actions
@@ -87,6 +156,15 @@ def test_create_graph_with_options(flight_server: FlightServer, flight_client: G
87156
)
88157

89158

159+
def test_create_graph_with_flaky_server(
160+
flaky_flight_server: FlakyFlightServer, flaky_flight_client: GdsArrowClient
161+
) -> None:
162+
flaky_flight_client.create_graph("g", "DB")
163+
actions = flaky_flight_server._actions
164+
assert len(actions) == flaky_flight_server.expected_retries()
165+
assert_action(actions[0], "v1/CREATE_GRAPH", {"name": "g", "database_name": "DB"})
166+
167+
90168
def test_create_graph_from_triplets_with_defaults(flight_server: FlightServer, flight_client: GdsArrowClient) -> None:
91169
flight_client.create_graph_from_triplets("g", "DB")
92170
actions = flight_server._actions
@@ -202,6 +280,22 @@ def test_get_node_property(flight_server: FlightServer, flight_client: GdsArrowC
202280
)
203281

204282

283+
def test_flakey_get_node_property(flaky_flight_server: FlakyFlightServer, flaky_flight_client: GdsArrowClient) -> None:
284+
flaky_flight_client.get_node_properties("g", "db", "id", ["Person"], concurrency=42)
285+
tickets = flaky_flight_server._tickets
286+
assert len(tickets) == flaky_flight_server.expected_retries()
287+
assert_ticket(
288+
tickets[0],
289+
{
290+
"concurrency": 42,
291+
"configuration": {"list_node_labels": False, "node_labels": ["Person"], "node_property": "id"},
292+
"database_name": "db",
293+
"graph_name": "g",
294+
"procedure_name": "gds.graph.nodeProperty.stream",
295+
},
296+
)
297+
298+
205299
def test_get_node_properties(flight_server: FlightServer, flight_client: GdsArrowClient) -> None:
206300
flight_client.get_node_properties("g", "db", ["foo", "bar"], ["Person"], list_node_labels=True, concurrency=42)
207301
tickets = flight_server._tickets
@@ -314,21 +408,21 @@ def test_auth_middleware_bad_headers() -> None:
314408

315409
def test_handle_flight_error() -> None:
316410
with pytest.raises(
317-
flight.FlightServerError,
411+
FlightServerError,
318412
match="FlightServerError: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database.",
319413
):
320414
GdsArrowClient.handle_flight_error(
321-
flight.FlightServerError(
415+
FlightServerError(
322416
'FlightServerError: Flight RPC failed with message: org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database.. gRPC client debug context: UNKNOWN:Error received from peer ipv4:35.241.177.75:8491 {created_time:"2024-08-29T15:59:03.828903999+02:00", grpc_status:2, grpc_message:"org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database."}. Client context: IOError: Server never sent a data message. Detail: Internal'
323417
)
324418
)
325419

326420
with pytest.raises(
327-
flight.FlightServerError,
421+
FlightServerError,
328422
match=re.escape("FlightServerError: UNKNOWN: Unexpected configuration key(s): [undirectedRelationshipTypes]"),
329423
):
330424
GdsArrowClient.handle_flight_error(
331-
flight.FlightServerError(
425+
FlightServerError(
332426
"FlightServerError: Flight returned internal error, with message: org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Unexpected configuration key(s): [undirectedRelationshipTypes]"
333427
)
334428
)

0 commit comments

Comments
 (0)