1
+ import base64
1
2
import logging
2
- from subprocess import call
3
+ from datetime import datetime , timezone
3
4
from typing import Sequence , Tuple
4
5
5
6
import clickhouse_driver
13
14
ClickHouseChain ,
14
15
ClickHouseRun ,
15
16
create_chain_table ,
16
- create_runs_table ,
17
17
)
18
- from mcbackend .core import Chain , Run , chain_id
18
+ from mcbackend .core import Run , chain_id
19
19
from mcbackend .meta import ChainMeta , RunMeta , Variable
20
20
from mcbackend .test_utils import CheckBehavior , CheckPerformance , make_runmeta
21
21
@@ -137,6 +137,19 @@ def test_init_run(self):
137
137
assert isinstance (runs , pandas .DataFrame )
138
138
assert runs .index .name == "rid"
139
139
assert "my_first_run" in runs .index .values
140
+
141
+ # Illegaly create a duplicate entry
142
+ created_at = datetime .now ().astimezone (timezone .utc )
143
+ query = "INSERT INTO runs (created_at, rid, proto) VALUES"
144
+ params = dict (
145
+ created_at = created_at ,
146
+ rid = meta .rid ,
147
+ proto = base64 .encodebytes (bytes (meta )).decode ("ascii" ),
148
+ )
149
+ self ._client .execute (query , [params ])
150
+ assert len (self ._client .execute ("SELECT * FROM runs;" )) == 2
151
+ with pytest .raises (Exception , match = "Unexpected number of 2 results" ):
152
+ self .backend .get_run ("my_first_run" )
140
153
pass
141
154
142
155
def test_get_run (self ):
@@ -160,6 +173,10 @@ def test_create_chain_table(self):
160
173
self .backend .init_run (rmeta )
161
174
cmeta = ChainMeta (rmeta .rid , 1 )
162
175
create_chain_table (self ._client , cmeta , rmeta )
176
+
177
+ with pytest .raises (Exception , match = "already exists" ):
178
+ create_chain_table (self ._client , cmeta , rmeta )
179
+
163
180
rows , names_and_types = self ._client .execute (
164
181
f"SELECT * FROM { chain_id (cmeta )} ;" , with_column_types = True
165
182
)
@@ -185,6 +202,23 @@ def test_create_chain_table_with_undefined_ndim(self, caplog):
185
202
assert "Assuming ndim=0" in caplog .records [0 ].message
186
203
pass
187
204
205
+ def test_get_chains_via_query (self ):
206
+ run , chains = fully_initialized (
207
+ self .backend ,
208
+ make_runmeta (
209
+ variables = [
210
+ Variable ("v1" , "uint16" , []),
211
+ Variable ("v2" , "float32" , list ((3 ,))),
212
+ Variable ("v3" , "float64" , [2 , 5 , 6 ]),
213
+ ],
214
+ ),
215
+ )
216
+ newrun = ClickHouseRun (run .meta , client_fn = self .backend ._client_fn )
217
+ chains_fetched = newrun .get_chains ()
218
+ assert len (chains_fetched ) > 0
219
+ assert len (chains_fetched ) == len (chains )
220
+ pass
221
+
188
222
def test_insert_draw (self ):
189
223
run , chains = fully_initialized (
190
224
self .backend ,
@@ -202,6 +236,10 @@ def test_insert_draw(self):
202
236
"v3" : numpy .random .uniform (size = (2 , 5 , 6 )).astype ("float64" ),
203
237
}
204
238
chain = chains [0 ]
239
+
240
+ with pytest .raises (Exception , match = "No draws in chain" ):
241
+ chain ._get_rows ("v1" , [], "uint16" )
242
+
205
243
chain .append (draw )
206
244
assert len (chain ._insert_queue ) == 1
207
245
chain ._commit ()
@@ -213,6 +251,9 @@ def test_insert_draw(self):
213
251
assert v1 == 12
214
252
numpy .testing .assert_array_equal (v2 , draw ["v2" ])
215
253
numpy .testing .assert_array_equal (v3 , draw ["v3" ])
254
+
255
+ with pytest .raises (Exception , match = "No record found for draw" ):
256
+ chain ._get_row_at (2 , var_names = ["v1" ])
216
257
pass
217
258
218
259
0 commit comments