4
4
5
5
import pyarrow as pa
6
6
import pytest
7
- from pyarrow import flight
8
7
from pyarrow ._flight import GeneratorStream
9
- from pyarrow .flight import Action , Ticket
8
+ from pyarrow .flight import (
9
+ Action ,
10
+ FlightServerBase ,
11
+ FlightServerError ,
12
+ FlightTimedOutError ,
13
+ FlightUnavailableError ,
14
+ Ticket ,
15
+ )
10
16
11
17
from graphdatascience .query_runner .gds_arrow_client import AuthMiddleware , GdsArrowClient
12
18
13
19
ActionParam = Union [str , tuple [str , Any ], Action ]
14
20
15
21
16
- class FlightServer (flight . FlightServerBase ): # type: ignore
22
+ class FlightServer (FlightServerBase ): # type: ignore
17
23
def __init__ (self , location : str = "grpc://0.0.0.0:0" , ** kwargs : dict [str , Any ]) -> None :
18
24
super (FlightServer , self ).__init__ (location , ** kwargs )
19
25
self ._location : str = location
@@ -49,18 +55,81 @@ def do_action(self, context: Any, action: ActionParam) -> list[bytes]:
49
55
return [json .dumps (response ).encode ("utf-8" )]
50
56
51
57
58
class FlakyFlightServer(FlightServerBase):  # type: ignore
    """Flight server test double whose first calls fail before it starts succeeding.

    The server holds a fixed list of Flight errors. Every call to ``do_get`` or
    ``do_action`` is recorded first; while errors remain, the call then raises
    the next one (taken from the end of the list) instead of answering. The
    list is shared between both endpoints, so the failures are counted across
    ``do_get`` and ``do_action`` combined.
    """

    def __init__(self, location: str = "grpc://0.0.0.0:0", **kwargs: dict[str, Any]) -> None:
        """Start the server at ``location`` and arm the queue of failures."""
        super().__init__(location, **kwargs)
        self._location: str = location
        # Every request is recorded, including the ones that are answered with an error.
        self._actions: list[ActionParam] = []
        self._tickets: list[Ticket] = []
        # Consumed with ``pop()``, i.e. the timed-out error is raised before the unavailable one.
        self._expected_failures = [
            FlightUnavailableError("Flight server is unavailable", "some reason"),
            FlightTimedOutError("Time out for some reason", "still timed out"),
        ]
        # One call per queued failure plus the final call that succeeds.
        self._expected_retries = len(self._expected_failures) + 1

    def expected_retries(self) -> int:
        """Return the number of calls needed to get one success (failures + 1)."""
        return self._expected_retries

    def _raise_pending_failure(self) -> None:
        """Raise the next queued failure, or do nothing once the queue is empty."""
        if self._expected_failures:
            raise self._expected_failures.pop()

    def do_get(self, context: Any, ticket: Ticket) -> GeneratorStream:
        """Record ``ticket``; fail while failures are queued, else stream a fixed table."""
        self._tickets.append(ticket)
        self._raise_pending_failure()

        table = pa.Table.from_pydict({"ids": [42, 1337, 1234]})
        return GeneratorStream(schema=table.schema, generator=table.to_batches())

    def do_action(self, context: Any, action: ActionParam) -> list[bytes]:
        """Record ``action``; fail while failures are queued, else return a canned JSON reply.

        Raises:
            TypeError: if ``action`` is not an ``Action``, a tuple or a string.
        """
        self._actions.append(action)
        self._raise_pending_failure()

        if isinstance(action, Action):
            action_type = action.type
        elif isinstance(action, tuple):
            action_type = action[0]
        elif isinstance(action, str):
            action_type = action
        else:
            # Previously this fell through and crashed with an UnboundLocalError.
            raise TypeError(f"Unsupported action type: {type(action).__name__}")

        # Substring checks, evaluated in this order; unknown actions get an empty object.
        response: dict[str, Any]
        if "CREATE" in action_type:
            response = {"name": "g"}
        elif "NODE_LOAD_DONE" in action_type:
            response = {"name": "g", "node_count": 42}
        elif "RELATIONSHIP_LOAD_DONE" in action_type:
            response = {"name": "g", "relationship_count": 42}
        elif "TRIPLET_LOAD_DONE" in action_type:
            response = {"name": "g", "node_count": 42, "relationship_count": 1337}
        else:
            response = {}
        return [json.dumps(response).encode("utf-8")]
107
+
108
+
52
109
@pytest.fixture()
def flight_server() -> Generator[FlightServer, None, None]:
    """Yield a running ``FlightServer``; leaving the ``with`` block ends its context.

    The annotation is ``Generator[YieldType, SendType, ReturnType]``: the
    fixture yields the server and neither receives nor returns a value.
    """
    with FlightServer() as server:
        yield server
56
113
57
114
115
@pytest.fixture()
def flaky_flight_server() -> Generator[FlakyFlightServer, None, None]:
    """Yield a running ``FlakyFlightServer``; leaving the ``with`` block ends its context.

    The annotation is ``Generator[YieldType, SendType, ReturnType]``: the
    fixture yields the server and neither receives nor returns a value.
    """
    with FlakyFlightServer() as server:
        yield server
119
+
120
+
58
121
@pytest.fixture()
def flight_client(flight_server: FlightServer) -> Generator[GdsArrowClient, None, None]:
    """Yield a ``GdsArrowClient`` pointed at the ``flight_server`` fixture on localhost."""
    server_port = flight_server.port
    with GdsArrowClient("localhost", server_port) as connected_client:
        yield connected_client
62
125
63
126
127
@pytest.fixture()
def flaky_flight_client(flaky_flight_server: FlakyFlightServer) -> Generator[GdsArrowClient, None, None]:
    """Yield a ``GdsArrowClient`` pointed at the ``flaky_flight_server`` fixture on localhost."""
    server_port = flaky_flight_server.port
    with GdsArrowClient("localhost", server_port) as connected_client:
        yield connected_client
131
+
132
+
64
133
def test_create_graph_with_defaults (flight_server : FlightServer , flight_client : GdsArrowClient ) -> None :
65
134
flight_client .create_graph ("g" , "DB" )
66
135
actions = flight_server ._actions
@@ -87,6 +156,15 @@ def test_create_graph_with_options(flight_server: FlightServer, flight_client: G
87
156
)
88
157
89
158
159
def test_create_graph_with_flaky_server(
    flaky_flight_server: FlakyFlightServer, flaky_flight_client: GdsArrowClient
) -> None:
    """Creating a graph against the flaky server should succeed after the queued failures."""
    flaky_flight_client.create_graph("g", "DB")

    # The server records every incoming action, including the ones it rejected.
    recorded_actions = flaky_flight_server._actions
    assert len(recorded_actions) == flaky_flight_server.expected_retries()

    first_action = recorded_actions[0]
    assert_action(first_action, "v1/CREATE_GRAPH", {"name": "g", "database_name": "DB"})
166
+
167
+
90
168
def test_create_graph_from_triplets_with_defaults (flight_server : FlightServer , flight_client : GdsArrowClient ) -> None :
91
169
flight_client .create_graph_from_triplets ("g" , "DB" )
92
170
actions = flight_server ._actions
@@ -202,6 +280,22 @@ def test_get_node_property(flight_server: FlightServer, flight_client: GdsArrowC
202
280
)
203
281
204
282
283
def test_flakey_get_node_property(flaky_flight_server: FlakyFlightServer, flaky_flight_client: GdsArrowClient) -> None:
    """Streaming a node property from the flaky server should succeed after the queued failures."""
    flaky_flight_client.get_node_properties("g", "db", "id", ["Person"], concurrency=42)

    # The server records every incoming ticket, including the ones it rejected.
    recorded_tickets = flaky_flight_server._tickets
    assert len(recorded_tickets) == flaky_flight_server.expected_retries()

    expected_ticket = {
        "concurrency": 42,
        "configuration": {"list_node_labels": False, "node_labels": ["Person"], "node_property": "id"},
        "database_name": "db",
        "graph_name": "g",
        "procedure_name": "gds.graph.nodeProperty.stream",
    }
    assert_ticket(recorded_tickets[0], expected_ticket)
297
+
298
+
205
299
def test_get_node_properties (flight_server : FlightServer , flight_client : GdsArrowClient ) -> None :
206
300
flight_client .get_node_properties ("g" , "db" , ["foo" , "bar" ], ["Person" ], list_node_labels = True , concurrency = 42 )
207
301
tickets = flight_server ._tickets
@@ -314,21 +408,21 @@ def test_auth_middleware_bad_headers() -> None:
314
408
315
409
def test_handle_flight_error() -> None:
    """``handle_flight_error`` should raise a ``FlightServerError`` matching the condensed message.

    Each case pairs the regex expected to match the raised error with the raw
    server message that is passed in.
    """
    cases = [
        (
            # Used as a regex as-is (not escaped).
            "FlightServerError: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database.",
            'FlightServerError: Flight RPC failed with message: org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database.. gRPC client debug context: UNKNOWN:Error received from peer ipv4:35.241.177.75:8491 {created_time:"2024-08-29T15:59:03.828903999+02:00", grpc_status:2, grpc_message:"org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Graph with name `people-and-fruits` does not exist on database `neo4j`. It might exist on another database."}. Client context: IOError: Server never sent a data message. Detail: Internal',
        ),
        (
            # Escaped because the message contains regex metacharacters: "(s)" and "[...]".
            re.escape("FlightServerError: UNKNOWN: Unexpected configuration key(s): [undirectedRelationshipTypes]"),
            "FlightServerError: Flight returned internal error, with message: org.apache.arrow.flight.FlightRuntimeException: UNKNOWN: Unexpected configuration key(s): [undirectedRelationshipTypes]",
        ),
    ]

    for expected_pattern, raw_message in cases:
        with pytest.raises(FlightServerError, match=expected_pattern):
            GdsArrowClient.handle_flight_error(FlightServerError(raw_message))
0 commit comments