19
19
ClientMiddleware ,
20
20
ClientMiddlewareFactory ,
21
21
FlightDescriptor ,
22
+ FlightInternalError ,
22
23
FlightMetadataReader ,
23
24
FlightStreamWriter ,
24
25
FlightTimedOutError ,
34
35
wait_exponential ,
35
36
)
36
37
38
+ from graphdatascience .retry_utils .retry_config import RetryConfig
37
39
from graphdatascience .retry_utils .retry_utils import before_log
38
40
39
41
from ..semantic_version .semantic_version import SemanticVersion
40
42
from ..version import __version__
41
43
from .arrow_endpoint_version import ArrowEndpointVersion
42
44
from .arrow_info import ArrowInfo
43
45
44
- _arrow_client_logger = logging .getLogger ("gds_arrow_client" )
45
-
46
46
47
47
class GdsArrowClient :
48
48
@staticmethod
@@ -53,6 +53,7 @@ def create(
53
53
disable_server_verification : bool = False ,
54
54
tls_root_certs : Optional [bytes ] = None ,
55
55
connection_string_override : Optional [str ] = None ,
56
+ retry_config : Optional [RetryConfig ] = None ,
56
57
) -> GdsArrowClient :
57
58
connection_string : str
58
59
if connection_string_override is not None :
@@ -64,8 +65,20 @@ def create(
64
65
65
66
arrow_endpoint_version = ArrowEndpointVersion .from_arrow_info (arrow_info .versions )
66
67
68
+ if retry_config is None :
69
+ retry_config = RetryConfig (
70
+ retry = retry_any (
71
+ retry_if_exception_type (FlightTimedOutError ),
72
+ retry_if_exception_type (FlightUnavailableError ),
73
+ retry_if_exception_type (FlightInternalError ),
74
+ ),
75
+ stop = (stop_after_delay (10 ) | stop_after_attempt (5 )),
76
+ wait = wait_exponential (multiplier = 1 , min = 1 , max = 10 ),
77
+ )
78
+
67
79
return GdsArrowClient (
68
80
host ,
81
+ retry_config ,
69
82
int (port ),
70
83
auth ,
71
84
encrypted ,
@@ -77,6 +90,7 @@ def create(
77
90
def __init__ (
78
91
self ,
79
92
host : str ,
93
+ retry_config : RetryConfig ,
80
94
port : int = 8491 ,
81
95
auth : Optional [tuple [str , str ]] = None ,
82
96
encrypted : bool = False ,
@@ -105,6 +119,8 @@ def __init__(
105
119
The version of the Arrow endpoint to use (default is ArrowEndpointVersion.V1)
106
120
user_agent: Optional[str]
107
121
The user agent string to use for the connection. (default is `neo4j-graphdatascience-v[VERSION] pyarrow-v[PYARROW_VERSION])
122
+ retry_config: Optional[RetryConfig]
123
+ The retry configuration to use for the Arrow requests send by the client.
108
124
"""
109
125
self ._arrow_endpoint_version = arrow_endpoint_version
110
126
self ._host = host
@@ -114,6 +130,8 @@ def __init__(
114
130
self ._disable_server_verification = disable_server_verification
115
131
self ._tls_root_certs = tls_root_certs
116
132
self ._user_agent = user_agent
133
+ self ._retry_config = retry_config
134
+ self ._logger = logging .getLogger ("gds_arrow_client" )
117
135
118
136
if auth :
119
137
self ._auth_middleware = AuthMiddleware (auth )
@@ -151,13 +169,6 @@ def connection_info(self) -> tuple[str, int]:
151
169
"""
152
170
return self ._host , self ._port
153
171
154
- @retry (
155
- reraise = True ,
156
- before = before_log ("Request token" , _arrow_client_logger , logging .DEBUG ),
157
- retry = retry_any (retry_if_exception_type (FlightTimedOutError ), retry_if_exception_type (FlightUnavailableError )),
158
- stop = (stop_after_delay (10 ) | stop_after_attempt (5 )),
159
- wait = wait_exponential (multiplier = 1 , min = 1 , max = 10 ),
160
- )
161
172
def request_token (self ) -> Optional [str ]:
162
173
"""
163
174
Requests a token from the server and returns it.
@@ -167,9 +178,21 @@ def request_token(self) -> Optional[str]:
167
178
Optional[str]
168
179
a token from the server and returns it.
169
180
"""
170
- if self ._auth :
181
+
182
+ @retry (
183
+ reraise = True ,
184
+ before = before_log ("Request token" , self ._logger , logging .DEBUG ),
185
+ retry = self ._retry_config .retry ,
186
+ stop = self ._retry_config .stop ,
187
+ wait = self ._retry_config .wait ,
188
+ )
189
+ def auth_with_retry () -> None :
171
190
client = self ._client ()
172
- client .authenticate_basic_token (self ._auth [0 ], self ._auth [1 ])
191
+ if self ._auth :
192
+ client .authenticate_basic_token (self ._auth [0 ], self ._auth [1 ])
193
+
194
+ if self ._auth :
195
+ auth_with_retry ()
173
196
return self ._auth_middleware .token ()
174
197
else :
175
198
return "IGNORED"
@@ -220,7 +243,7 @@ def get_node_properties(
220
243
if node_labels :
221
244
config ["node_labels" ] = node_labels
222
245
223
- return self ._do_get (database , graph_name , proc , concurrency , config )
246
+ return self ._do_get_with_retry (database , graph_name , proc , concurrency , config )
224
247
225
248
def get_node_labels (self , graph_name : str , database : str , concurrency : Optional [int ] = None ) -> pandas .DataFrame :
226
249
"""
@@ -240,7 +263,7 @@ def get_node_labels(self, graph_name: str, database: str, concurrency: Optional[
240
263
DataFrame
241
264
The requested nodes as a DataFrame
242
265
"""
243
- return self ._do_get (database , graph_name , "gds.graph.nodeLabels.stream" , concurrency , {})
266
+ return self ._do_get_with_retry (database , graph_name , "gds.graph.nodeLabels.stream" , concurrency , {})
244
267
245
268
def get_relationships (
246
269
self , graph_name : str , database : str , relationship_types : list [str ], concurrency : Optional [int ] = None
@@ -264,7 +287,7 @@ def get_relationships(
264
287
DataFrame
265
288
The requested relationships as a DataFrame
266
289
"""
267
- return self ._do_get (
290
+ return self ._do_get_with_retry (
268
291
database ,
269
292
graph_name ,
270
293
"gds.graph.relationships.stream" ,
@@ -312,7 +335,7 @@ def get_relationship_properties(
312
335
if relationship_types :
313
336
config ["relationship_types" ] = relationship_types
314
337
315
- return self ._do_get (database , graph_name , proc , concurrency , config )
338
+ return self ._do_get_with_retry (database , graph_name , proc , concurrency , config )
316
339
317
340
def create_graph (
318
341
self ,
@@ -598,40 +621,31 @@ def _client(self) -> flight.FlightClient:
598
621
self ._flight_client = self ._instantiate_flight_client ()
599
622
return self ._flight_client
600
623
601
- @retry (
602
- reraise = True ,
603
- before = before_log ("Send action" , _arrow_client_logger , logging .DEBUG ),
604
- retry = retry_any (retry_if_exception_type (FlightTimedOutError ), retry_if_exception_type (FlightUnavailableError )),
605
- stop = (stop_after_delay (10 ) | stop_after_attempt (5 )),
606
- wait = wait_exponential (multiplier = 1 , min = 1 , max = 10 ),
607
- )
608
624
def _send_action (self , action_type : str , meta_data : dict [str , Any ]) -> dict [str , Any ]:
609
625
action_type = self ._versioned_action_type (action_type )
626
+ client = self ._client ()
610
627
611
- try :
612
- client = self ._client ()
613
- result = client .do_action (flight .Action (action_type , json .dumps (meta_data ).encode ("utf-8" )))
628
+ @retry (
629
+ reraise = True ,
630
+ before = before_log ("Send action" , self ._logger , logging .DEBUG ),
631
+ retry = self ._retry_config .retry ,
632
+ stop = self ._retry_config .stop ,
633
+ wait = self ._retry_config .wait ,
634
+ )
635
+ def send_with_retry () -> dict [str , Any ]:
636
+ try :
637
+ result = client .do_action (flight .Action (action_type , json .dumps (meta_data ).encode ("utf-8" )))
614
638
615
- # Consume result fully to sanity check and avoid cancelled streams
616
- collected_result = list (result )
617
- assert len (collected_result ) == 1
639
+ # Consume result fully to sanity check and avoid cancelled streams
640
+ collected_result = list (result )
641
+ assert len (collected_result ) == 1
618
642
619
- return json .loads (collected_result [0 ].body .to_pybytes ().decode ()) # type: ignore
620
- except Exception as e :
621
- self .handle_flight_error (e )
622
- raise e # unreachable
623
-
624
- @retry (
625
- reraise = True ,
626
- before = before_log ("Do put" , _arrow_client_logger , logging .DEBUG ),
627
- retry = retry_any (retry_if_exception_type (FlightTimedOutError ), retry_if_exception_type (FlightUnavailableError )),
628
- stop = (stop_after_delay (10 ) | stop_after_attempt (5 )),
629
- wait = wait_exponential (multiplier = 1 , min = 1 , max = 10 ),
630
- )
631
- def _safe_do_put (
632
- self , upload_descriptor : FlightDescriptor , schema : Schema
633
- ) -> tuple [FlightStreamWriter , FlightMetadataReader ]:
634
- return self ._client ().do_put (upload_descriptor , schema ) # type: ignore
643
+ return json .loads (collected_result [0 ].body .to_pybytes ().decode ()) # type: ignore
644
+ except Exception as e :
645
+ self .handle_flight_error (e )
646
+ raise e # unreachable
647
+
648
+ return send_with_retry ()
635
649
636
650
def _upload_data (
637
651
self ,
@@ -651,18 +665,26 @@ def _upload_data(
651
665
flight_descriptor = self ._versioned_flight_descriptor ({"name" : graph_name , "entity_type" : entity_type })
652
666
upload_descriptor = flight .FlightDescriptor .for_command (json .dumps (flight_descriptor ).encode ("utf-8" ))
653
667
654
- put_stream , ack_stream = self ._safe_do_put (upload_descriptor , batches [0 ].schema )
668
+ @retry (
669
+ reraise = True ,
670
+ before = before_log ("Do put" , self ._logger , logging .DEBUG ),
671
+ retry = self ._retry_config .retry ,
672
+ stop = self ._retry_config .stop ,
673
+ wait = self ._retry_config .wait ,
674
+ )
675
+ def safe_do_put (
676
+ upload_descriptor : FlightDescriptor , schema : Schema
677
+ ) -> tuple [FlightStreamWriter , FlightMetadataReader ]:
678
+ return self ._client ().do_put (upload_descriptor , schema ) # type: ignore
679
+
680
+ put_stream , ack_stream = safe_do_put (upload_descriptor , batches [0 ].schema )
655
681
656
682
@retry (
657
683
reraise = True ,
658
- before = before_log ("Upload batch" , _arrow_client_logger , logging .DEBUG ),
659
- stop = (stop_after_delay (10 ) | stop_after_attempt (5 )),
660
- wait = wait_exponential (multiplier = 1 , min = 1 , max = 10 ),
661
- retry = (
662
- retry_if_exception_type (flight .FlightUnavailableError )
663
- | retry_if_exception_type (flight .FlightTimedOutError )
664
- | retry_if_exception_type (flight .FlightInternalError )
665
- ),
684
+ before = before_log ("Upload batch" , self ._logger , logging .DEBUG ),
685
+ retry = self ._retry_config .retry ,
686
+ stop = self ._retry_config .stop ,
687
+ wait = self ._retry_config .wait ,
666
688
)
667
689
def upload_batch (p : RecordBatch ) -> None :
668
690
put_stream .write_batch (p )
@@ -676,13 +698,26 @@ def upload_batch(p: RecordBatch) -> None:
676
698
except Exception as e :
677
699
GdsArrowClient .handle_flight_error (e )
678
700
679
- @retry (
680
- reraise = True ,
681
- before = before_log ("Do get" , _arrow_client_logger , logging .DEBUG ),
682
- retry = retry_any (retry_if_exception_type (FlightTimedOutError ), retry_if_exception_type (FlightUnavailableError )),
683
- stop = (stop_after_delay (10 ) | stop_after_attempt (5 )),
684
- wait = wait_exponential (multiplier = 1 , min = 1 , max = 10 ),
685
- )
701
+ def _do_get_with_retry (
702
+ self ,
703
+ database : str ,
704
+ graph_name : str ,
705
+ procedure_name : str ,
706
+ concurrency : Optional [int ],
707
+ configuration : dict [str , Any ],
708
+ ) -> pandas .DataFrame :
709
+ @retry (
710
+ reraise = True ,
711
+ before = before_log ("Do get" , self ._logger , logging .DEBUG ),
712
+ retry = self ._retry_config .retry ,
713
+ stop = self ._retry_config .stop ,
714
+ wait = self ._retry_config .wait ,
715
+ )
716
+ def safe_do_get () -> pandas .DataFrame :
717
+ return self ._do_get (database , graph_name , procedure_name , concurrency , configuration )
718
+
719
+ return safe_do_get ()
720
+
686
721
def _do_get (
687
722
self ,
688
723
database : str ,
0 commit comments