Skip to content

Commit 586d2b6

Browse files
Increase test coverage
1 parent b33fe73 commit 586d2b6

File tree

3 files changed

+48
-7
lines changed

3 files changed

+48
-7
lines changed

mcbackend/test_adapter_pymc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ def simple_model():
2727
"condition": ["A", "B", "C"],
2828
}
2929
) as pmodel:
30-
x = pm.Data("seconds", seconds, dims="time")
30+
x = pm.ConstantData("seconds", seconds, dims="time")
3131
a = pm.Normal("scalar")
3232
b = pm.Uniform("vector", dims="condition")
3333
pm.Deterministic("matrix", a + b[:, None] * x[None, :], dims=("condition", "time"))
3434
pm.Bernoulli("integer", p=0.5)
35-
obs = pm.Data("obs", observations, dims=("condition", "time"))
35+
obs = pm.MutableData("obs", observations, dims=("condition", "time"))
3636
pm.Normal("L", pmodel["matrix"], observed=obs, dims=("condition", "time"))
3737
return pmodel
3838

@@ -128,7 +128,7 @@ def wrapper(meta: RunMeta):
128128
assert tuple(seconds.dims) == ("time",)
129129
assert not seconds.is_observed
130130
numpy.testing.assert_array_equal(
131-
ndarray_to_numpy(seconds.value), simple_model["seconds"].get_value()
131+
ndarray_to_numpy(seconds.value), simple_model["seconds"].data
132132
)
133133
# Observed data variables
134134
assert "obs" in datavars

mcbackend/test_backend_clickhouse.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
import base64
12
import logging
2-
from subprocess import call
3+
from datetime import datetime, timezone
34
from typing import Sequence, Tuple
45

56
import clickhouse_driver
@@ -13,9 +14,8 @@
1314
ClickHouseChain,
1415
ClickHouseRun,
1516
create_chain_table,
16-
create_runs_table,
1717
)
18-
from mcbackend.core import Chain, Run, chain_id
18+
from mcbackend.core import Run, chain_id
1919
from mcbackend.meta import ChainMeta, RunMeta, Variable
2020
from mcbackend.test_utils import CheckBehavior, CheckPerformance, make_runmeta
2121

@@ -137,6 +137,19 @@ def test_init_run(self):
137137
assert isinstance(runs, pandas.DataFrame)
138138
assert runs.index.name == "rid"
139139
assert "my_first_run" in runs.index.values
140+
141+
# Illegaly create a duplicate entry
142+
created_at = datetime.now().astimezone(timezone.utc)
143+
query = "INSERT INTO runs (created_at, rid, proto) VALUES"
144+
params = dict(
145+
created_at=created_at,
146+
rid=meta.rid,
147+
proto=base64.encodebytes(bytes(meta)).decode("ascii"),
148+
)
149+
self._client.execute(query, [params])
150+
assert len(self._client.execute("SELECT * FROM runs;")) == 2
151+
with pytest.raises(Exception, match="Unexpected number of 2 results"):
152+
self.backend.get_run("my_first_run")
140153
pass
141154

142155
def test_get_run(self):
@@ -160,6 +173,10 @@ def test_create_chain_table(self):
160173
self.backend.init_run(rmeta)
161174
cmeta = ChainMeta(rmeta.rid, 1)
162175
create_chain_table(self._client, cmeta, rmeta)
176+
177+
with pytest.raises(Exception, match="already exists"):
178+
create_chain_table(self._client, cmeta, rmeta)
179+
163180
rows, names_and_types = self._client.execute(
164181
f"SELECT * FROM {chain_id(cmeta)};", with_column_types=True
165182
)
@@ -185,6 +202,23 @@ def test_create_chain_table_with_undefined_ndim(self, caplog):
185202
assert "Assuming ndim=0" in caplog.records[0].message
186203
pass
187204

205+
def test_get_chains_via_query(self):
206+
run, chains = fully_initialized(
207+
self.backend,
208+
make_runmeta(
209+
variables=[
210+
Variable("v1", "uint16", []),
211+
Variable("v2", "float32", list((3,))),
212+
Variable("v3", "float64", [2, 5, 6]),
213+
],
214+
),
215+
)
216+
newrun = ClickHouseRun(run.meta, client_fn=self.backend._client_fn)
217+
chains_fetched = newrun.get_chains()
218+
assert len(chains_fetched) > 0
219+
assert len(chains_fetched) == len(chains)
220+
pass
221+
188222
def test_insert_draw(self):
189223
run, chains = fully_initialized(
190224
self.backend,
@@ -202,6 +236,10 @@ def test_insert_draw(self):
202236
"v3": numpy.random.uniform(size=(2, 5, 6)).astype("float64"),
203237
}
204238
chain = chains[0]
239+
240+
with pytest.raises(Exception, match="No draws in chain"):
241+
chain._get_rows("v1", [], "uint16")
242+
205243
chain.append(draw)
206244
assert len(chain._insert_queue) == 1
207245
chain._commit()
@@ -213,6 +251,9 @@ def test_insert_draw(self):
213251
assert v1 == 12
214252
numpy.testing.assert_array_equal(v2, draw["v2"])
215253
numpy.testing.assert_array_equal(v3, draw["v3"])
254+
255+
with pytest.raises(Exception, match="No record found for draw"):
256+
chain._get_row_at(2, var_names=["v1"])
216257
pass
217258

218259

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
arviz
22
clickhouse-driver
33
flake8
4-
pymc==4.0.0b1
4+
pymc==4.0.0b2
55
pytest
66
pytest-cov
77
twine

0 commit comments

Comments
 (0)