Skip to content

Commit 85c30bc

Browse files
committed
Allow configuring custom retry_config for arrow client
1 parent 78c52b2 commit 85c30bc

File tree

5 files changed

+123
-63
lines changed

5 files changed

+123
-63
lines changed

graphdatascience/query_runner/arrow_query_runner.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from pandas import DataFrame
77

8+
from graphdatascience.retry_utils.retry_config import RetryConfig
9+
810
from ..call_parameters import CallParameters
911
from ..query_runner.arrow_info import ArrowInfo
1012
from ..server_version.server_version import ServerVersion
@@ -24,6 +26,7 @@ def create(
2426
disable_server_verification: bool = False,
2527
tls_root_certs: Optional[bytes] = None,
2628
connection_string_override: Optional[str] = None,
29+
retry_config: Optional[RetryConfig] = None,
2730
) -> ArrowQueryRunner:
2831
if not arrow_info.enabled:
2932
raise ValueError("Arrow is not enabled on the server")
@@ -35,6 +38,7 @@ def create(
3538
disable_server_verification,
3639
tls_root_certs,
3740
connection_string_override,
41+
retry_config=retry_config,
3842
)
3943

4044
return ArrowQueryRunner(gds_arrow_client, fallback_query_runner, fallback_query_runner.server_version())

graphdatascience/query_runner/gds_arrow_client.py

Lines changed: 95 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
ClientMiddleware,
2020
ClientMiddlewareFactory,
2121
FlightDescriptor,
22+
FlightInternalError,
2223
FlightMetadataReader,
2324
FlightStreamWriter,
2425
FlightTimedOutError,
@@ -34,15 +35,14 @@
3435
wait_exponential,
3536
)
3637

38+
from graphdatascience.retry_utils.retry_config import RetryConfig
3739
from graphdatascience.retry_utils.retry_utils import before_log
3840

3941
from ..semantic_version.semantic_version import SemanticVersion
4042
from ..version import __version__
4143
from .arrow_endpoint_version import ArrowEndpointVersion
4244
from .arrow_info import ArrowInfo
4345

44-
_arrow_client_logger = logging.getLogger("gds_arrow_client")
45-
4646

4747
class GdsArrowClient:
4848
@staticmethod
@@ -53,6 +53,7 @@ def create(
5353
disable_server_verification: bool = False,
5454
tls_root_certs: Optional[bytes] = None,
5555
connection_string_override: Optional[str] = None,
56+
retry_config: Optional[RetryConfig] = None,
5657
) -> GdsArrowClient:
5758
connection_string: str
5859
if connection_string_override is not None:
@@ -64,8 +65,20 @@ def create(
6465

6566
arrow_endpoint_version = ArrowEndpointVersion.from_arrow_info(arrow_info.versions)
6667

68+
if retry_config is None:
69+
retry_config = RetryConfig(
70+
retry=retry_any(
71+
retry_if_exception_type(FlightTimedOutError),
72+
retry_if_exception_type(FlightUnavailableError),
73+
retry_if_exception_type(FlightInternalError),
74+
),
75+
stop=(stop_after_delay(10) | stop_after_attempt(5)),
76+
wait=wait_exponential(multiplier=1, min=1, max=10),
77+
)
78+
6779
return GdsArrowClient(
6880
host,
81+
retry_config,
6982
int(port),
7083
auth,
7184
encrypted,
@@ -77,6 +90,7 @@ def create(
7790
def __init__(
7891
self,
7992
host: str,
93+
retry_config: RetryConfig,
8094
port: int = 8491,
8195
auth: Optional[tuple[str, str]] = None,
8296
encrypted: bool = False,
@@ -105,6 +119,8 @@ def __init__(
105119
The version of the Arrow endpoint to use (default is ArrowEndpointVersion.V1)
106120
user_agent: Optional[str]
107121
The user agent string to use for the connection. (default is `neo4j-graphdatascience-v[VERSION] pyarrow-v[PYARROW_VERSION]`)
122+
retry_config: RetryConfig
123+
The retry configuration to use for the Arrow requests sent by the client.
108124
"""
109125
self._arrow_endpoint_version = arrow_endpoint_version
110126
self._host = host
@@ -114,6 +130,8 @@ def __init__(
114130
self._disable_server_verification = disable_server_verification
115131
self._tls_root_certs = tls_root_certs
116132
self._user_agent = user_agent
133+
self._retry_config = retry_config
134+
self._logger = logging.getLogger("gds_arrow_client")
117135

118136
if auth:
119137
self._auth_middleware = AuthMiddleware(auth)
@@ -151,13 +169,6 @@ def connection_info(self) -> tuple[str, int]:
151169
"""
152170
return self._host, self._port
153171

154-
@retry(
155-
reraise=True,
156-
before=before_log("Request token", _arrow_client_logger, logging.DEBUG),
157-
retry=retry_any(retry_if_exception_type(FlightTimedOutError), retry_if_exception_type(FlightUnavailableError)),
158-
stop=(stop_after_delay(10) | stop_after_attempt(5)),
159-
wait=wait_exponential(multiplier=1, min=1, max=10),
160-
)
161172
def request_token(self) -> Optional[str]:
162173
"""
163174
Requests a token from the server and returns it.
@@ -167,9 +178,21 @@ def request_token(self) -> Optional[str]:
167178
Optional[str]
168179
The token obtained from the server, or "IGNORED" if no credentials are configured.
169180
"""
170-
if self._auth:
181+
182+
@retry(
183+
reraise=True,
184+
before=before_log("Request token", self._logger, logging.DEBUG),
185+
retry=self._retry_config.retry,
186+
stop=self._retry_config.stop,
187+
wait=self._retry_config.wait,
188+
)
189+
def auth_with_retry() -> None:
171190
client = self._client()
172-
client.authenticate_basic_token(self._auth[0], self._auth[1])
191+
if self._auth:
192+
client.authenticate_basic_token(self._auth[0], self._auth[1])
193+
194+
if self._auth:
195+
auth_with_retry()
173196
return self._auth_middleware.token()
174197
else:
175198
return "IGNORED"
@@ -220,7 +243,7 @@ def get_node_properties(
220243
if node_labels:
221244
config["node_labels"] = node_labels
222245

223-
return self._do_get(database, graph_name, proc, concurrency, config)
246+
return self._do_get_with_retry(database, graph_name, proc, concurrency, config)
224247

225248
def get_node_labels(self, graph_name: str, database: str, concurrency: Optional[int] = None) -> pandas.DataFrame:
226249
"""
@@ -240,7 +263,7 @@ def get_node_labels(self, graph_name: str, database: str, concurrency: Optional[
240263
DataFrame
241264
The requested nodes as a DataFrame
242265
"""
243-
return self._do_get(database, graph_name, "gds.graph.nodeLabels.stream", concurrency, {})
266+
return self._do_get_with_retry(database, graph_name, "gds.graph.nodeLabels.stream", concurrency, {})
244267

245268
def get_relationships(
246269
self, graph_name: str, database: str, relationship_types: list[str], concurrency: Optional[int] = None
@@ -264,7 +287,7 @@ def get_relationships(
264287
DataFrame
265288
The requested relationships as a DataFrame
266289
"""
267-
return self._do_get(
290+
return self._do_get_with_retry(
268291
database,
269292
graph_name,
270293
"gds.graph.relationships.stream",
@@ -312,7 +335,7 @@ def get_relationship_properties(
312335
if relationship_types:
313336
config["relationship_types"] = relationship_types
314337

315-
return self._do_get(database, graph_name, proc, concurrency, config)
338+
return self._do_get_with_retry(database, graph_name, proc, concurrency, config)
316339

317340
def create_graph(
318341
self,
@@ -598,40 +621,31 @@ def _client(self) -> flight.FlightClient:
598621
self._flight_client = self._instantiate_flight_client()
599622
return self._flight_client
600623

601-
@retry(
602-
reraise=True,
603-
before=before_log("Send action", _arrow_client_logger, logging.DEBUG),
604-
retry=retry_any(retry_if_exception_type(FlightTimedOutError), retry_if_exception_type(FlightUnavailableError)),
605-
stop=(stop_after_delay(10) | stop_after_attempt(5)),
606-
wait=wait_exponential(multiplier=1, min=1, max=10),
607-
)
608624
def _send_action(self, action_type: str, meta_data: dict[str, Any]) -> dict[str, Any]:
609625
action_type = self._versioned_action_type(action_type)
626+
client = self._client()
610627

611-
try:
612-
client = self._client()
613-
result = client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8")))
628+
@retry(
629+
reraise=True,
630+
before=before_log("Send action", self._logger, logging.DEBUG),
631+
retry=self._retry_config.retry,
632+
stop=self._retry_config.stop,
633+
wait=self._retry_config.wait,
634+
)
635+
def send_with_retry() -> dict[str, Any]:
636+
try:
637+
result = client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8")))
614638

615-
# Consume result fully to sanity check and avoid cancelled streams
616-
collected_result = list(result)
617-
assert len(collected_result) == 1
639+
# Consume result fully to sanity check and avoid cancelled streams
640+
collected_result = list(result)
641+
assert len(collected_result) == 1
618642

619-
return json.loads(collected_result[0].body.to_pybytes().decode()) # type: ignore
620-
except Exception as e:
621-
self.handle_flight_error(e)
622-
raise e # unreachable
623-
624-
@retry(
625-
reraise=True,
626-
before=before_log("Do put", _arrow_client_logger, logging.DEBUG),
627-
retry=retry_any(retry_if_exception_type(FlightTimedOutError), retry_if_exception_type(FlightUnavailableError)),
628-
stop=(stop_after_delay(10) | stop_after_attempt(5)),
629-
wait=wait_exponential(multiplier=1, min=1, max=10),
630-
)
631-
def _safe_do_put(
632-
self, upload_descriptor: FlightDescriptor, schema: Schema
633-
) -> tuple[FlightStreamWriter, FlightMetadataReader]:
634-
return self._client().do_put(upload_descriptor, schema) # type: ignore
643+
return json.loads(collected_result[0].body.to_pybytes().decode()) # type: ignore
644+
except Exception as e:
645+
self.handle_flight_error(e)
646+
raise e # unreachable
647+
648+
return send_with_retry()
635649

636650
def _upload_data(
637651
self,
@@ -651,18 +665,26 @@ def _upload_data(
651665
flight_descriptor = self._versioned_flight_descriptor({"name": graph_name, "entity_type": entity_type})
652666
upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8"))
653667

654-
put_stream, ack_stream = self._safe_do_put(upload_descriptor, batches[0].schema)
668+
@retry(
669+
reraise=True,
670+
before=before_log("Do put", self._logger, logging.DEBUG),
671+
retry=self._retry_config.retry,
672+
stop=self._retry_config.stop,
673+
wait=self._retry_config.wait,
674+
)
675+
def safe_do_put(
676+
upload_descriptor: FlightDescriptor, schema: Schema
677+
) -> tuple[FlightStreamWriter, FlightMetadataReader]:
678+
return self._client().do_put(upload_descriptor, schema) # type: ignore
679+
680+
put_stream, ack_stream = safe_do_put(upload_descriptor, batches[0].schema)
655681

656682
@retry(
657683
reraise=True,
658-
before=before_log("Upload batch", _arrow_client_logger, logging.DEBUG),
659-
stop=(stop_after_delay(10) | stop_after_attempt(5)),
660-
wait=wait_exponential(multiplier=1, min=1, max=10),
661-
retry=(
662-
retry_if_exception_type(flight.FlightUnavailableError)
663-
| retry_if_exception_type(flight.FlightTimedOutError)
664-
| retry_if_exception_type(flight.FlightInternalError)
665-
),
684+
before=before_log("Upload batch", self._logger, logging.DEBUG),
685+
retry=self._retry_config.retry,
686+
stop=self._retry_config.stop,
687+
wait=self._retry_config.wait,
666688
)
667689
def upload_batch(p: RecordBatch) -> None:
668690
put_stream.write_batch(p)
@@ -676,13 +698,26 @@ def upload_batch(p: RecordBatch) -> None:
676698
except Exception as e:
677699
GdsArrowClient.handle_flight_error(e)
678700

679-
@retry(
680-
reraise=True,
681-
before=before_log("Do get", _arrow_client_logger, logging.DEBUG),
682-
retry=retry_any(retry_if_exception_type(FlightTimedOutError), retry_if_exception_type(FlightUnavailableError)),
683-
stop=(stop_after_delay(10) | stop_after_attempt(5)),
684-
wait=wait_exponential(multiplier=1, min=1, max=10),
685-
)
701+
def _do_get_with_retry(
702+
self,
703+
database: str,
704+
graph_name: str,
705+
procedure_name: str,
706+
concurrency: Optional[int],
707+
configuration: dict[str, Any],
708+
) -> pandas.DataFrame:
709+
@retry(
710+
reraise=True,
711+
before=before_log("Do get", self._logger, logging.DEBUG),
712+
retry=self._retry_config.retry,
713+
stop=self._retry_config.stop,
714+
wait=self._retry_config.wait,
715+
)
716+
def safe_do_get() -> pandas.DataFrame:
717+
return self._do_get(database, graph_name, procedure_name, concurrency, configuration)
718+
719+
return safe_do_get()
720+
686721
def _do_get(
687722
self,
688723
database: str,
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from dataclasses import dataclass
2+
3+
from tenacity.retry import retry_base
4+
from tenacity.stop import stop_base
5+
from tenacity.wait import wait_base
6+
7+
8+
@dataclass(frozen=True, repr=True)
9+
class RetryConfig:
10+
stop: stop_base
11+
wait: wait_base
12+
retry: retry_base

graphdatascience/tests/unit/test_arrow_runner.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import pytest
22
from pyarrow.flight import FlightUnavailableError
3+
from tenacity import retry_any, stop_after_attempt, wait_fixed
34

45
from graphdatascience.query_runner.arrow_info import ArrowInfo
56
from graphdatascience.query_runner.arrow_query_runner import ArrowQueryRunner
7+
from graphdatascience.retry_utils.retry_config import RetryConfig
68
from graphdatascience.server_version.server_version import ServerVersion
79

810
from ...query_runner.arrow_endpoint_version import ArrowEndpointVersion
@@ -14,7 +16,13 @@ def test_create(runner: CollectingQueryRunner) -> None:
1416
arrow_info = ArrowInfo(
1517
listenAddress="localhost:1234", enabled=True, running=True, versions=[ArrowEndpointVersion.V1.version()]
1618
)
17-
arrow_runner = ArrowQueryRunner.create(runner, arrow_info)
19+
retry_config = RetryConfig(
20+
retry=retry_any(),
21+
stop=(stop_after_attempt(1)),
22+
wait=wait_fixed(0),
23+
)
24+
25+
arrow_runner = ArrowQueryRunner.create(runner, arrow_info, retry_config=retry_config)
1826

1927
assert isinstance(arrow_runner, ArrowQueryRunner)
2028

graphdatascience/tests/unit/test_gds_arrow_client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Ticket,
1515
)
1616

17+
from graphdatascience.query_runner.arrow_info import ArrowInfo
1718
from graphdatascience.query_runner.gds_arrow_client import AuthMiddleware, GdsArrowClient
1819

1920
ActionParam = Union[str, tuple[str, Any], Action]
@@ -120,13 +121,13 @@ def flaky_flight_server() -> Generator[None, FlakyFlightServer, None]:
120121

121122
@pytest.fixture()
122123
def flight_client(flight_server: FlightServer) -> Generator[GdsArrowClient, None, None]:
123-
with GdsArrowClient("localhost", flight_server.port) as client:
124+
with GdsArrowClient.create(ArrowInfo(f"localhost:{flight_server.port}", True, True, ["v1"])) as client:
124125
yield client
125126

126127

127128
@pytest.fixture()
128129
def flaky_flight_client(flaky_flight_server: FlakyFlightServer) -> Generator[GdsArrowClient, None, None]:
129-
with GdsArrowClient("localhost", flaky_flight_server.port) as client:
130+
with GdsArrowClient.create(ArrowInfo(f"localhost:{flaky_flight_server.port}", True, True, ["v1"])) as client:
130131
yield client
131132

132133

0 commit comments

Comments
 (0)