@@ -77,15 +77,20 @@ def test_terminate(self, ss, cc):
77
77
ss .process_message (cc , {"id" : "1" , "type" : constants .GQL_CONNECTION_TERMINATE })
78
78
ss .on_connection_terminate .assert_called_with (cc , "1" )
79
79
80
- def test_start (self , ss , cc ):
80
+ @pytest .mark .parametrize (
81
+ "transport_ws_protocol,expected_type" ,
82
+ ((False , constants .GQL_START ), (True , constants .GQL_SUBSCRIBE )),
83
+ )
84
+ def test_start (self , ss , cc , transport_ws_protocol , expected_type ):
81
85
ss .get_graphql_params = mock .Mock ()
82
86
ss .get_graphql_params .return_value = {"params" : True }
83
87
cc .has_operation = mock .Mock ()
84
88
cc .has_operation .return_value = False
89
+ cc .transport_ws_protocol = transport_ws_protocol
85
90
ss .unsubscribe = mock .Mock ()
86
91
ss .on_start = mock .Mock ()
87
92
ss .process_message (
88
- cc , {"id" : "1" , "type" : constants . GQL_START , "payload" : {"a" : "b" }}
93
+ cc , {"id" : "1" , "type" : expected_type , "payload" : {"a" : "b" }}
89
94
)
90
95
assert not ss .unsubscribe .called
91
96
ss .on_start .assert_called_with (cc , "1" , {"params" : True })
@@ -117,9 +122,32 @@ def test_start_bad_graphql_params(self, ss, cc):
117
122
assert isinstance (ss .send_error .call_args [0 ][2 ], Exception )
118
123
assert not ss .on_start .called
119
124
120
- def test_stop (self , ss , cc ):
125
+ @pytest .mark .parametrize (
126
+ "transport_ws_protocol,stop_type,invalid_stop_type" ,
127
+ (
128
+ (False , constants .GQL_STOP , constants .GQL_COMPLETE ),
129
+ (True , constants .GQL_COMPLETE , constants .GQL_STOP ),
130
+ ),
131
+ )
132
+ def test_stop (
133
+ self ,
134
+ ss ,
135
+ cc ,
136
+ transport_ws_protocol ,
137
+ stop_type ,
138
+ invalid_stop_type ,
139
+ ):
121
140
ss .on_stop = mock .Mock ()
122
- ss .process_message (cc , {"id" : "1" , "type" : constants .GQL_STOP })
141
+ ss .send_error = mock .Mock ()
142
+ cc .transport_ws_protocol = transport_ws_protocol
143
+
144
+ ss .process_message (cc , {"id" : "1" , "type" : invalid_stop_type })
145
+ assert ss .send_error .called
146
+ assert ss .send_error .call_args [0 ][:2 ] == (cc , "1" )
147
+ assert isinstance (ss .send_error .call_args [0 ][2 ], Exception )
148
+ assert not ss .on_stop .called
149
+
150
+ ss .process_message (cc , {"id" : "1" , "type" : stop_type })
123
151
ss .on_stop .assert_called_with (cc , "1" )
124
152
125
153
def test_invalid (self , ss , cc ):
@@ -165,13 +193,18 @@ def test_build_message_partial(ss):
165
193
ss .build_message (id = None , op_type = None , payload = None )
166
194
167
195
168
- def test_send_execution_result (ss , cc ):
196
+ @pytest .mark .parametrize (
197
+ "transport_ws_protocol,expected_type" ,
198
+ ((False , constants .GQL_DATA ), (True , constants .GQL_NEXT )),
199
+ )
200
+ def test_send_execution_result (ss , cc , transport_ws_protocol , expected_type ):
201
+ cc .transport_ws_protocol = transport_ws_protocol
169
202
ss .execution_result_to_dict = mock .Mock ()
170
203
ss .execution_result_to_dict .return_value = {"res" : "ult" }
171
204
ss .send_message = mock .Mock ()
172
205
ss .send_message .return_value = "returned"
173
206
assert "returned" == ss .send_execution_result (cc , "1" , "result" )
174
- ss .send_message .assert_called_with (cc , "1" , constants . GQL_DATA , {"res" : "ult" })
207
+ ss .send_message .assert_called_with (cc , "1" , expected_type , {"res" : "ult" })
175
208
176
209
177
210
def test_execution_result_to_dict (ss ):
0 commit comments