Skip to content

Commit dff8fb9

Browse files
Spaaktwiecki
authored andcommitted
updating test_pandas_to_array for Theano 1.1.0 compatibility
1 parent 6ad83f6 commit dff8fb9

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

pymc3/tests/test_model_helpers.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import numpy.ma as ma
1717
import numpy.testing as npt
1818
import pandas as pd
19+
import pytest
1920
import scipy.sparse as sps
2021
import theano
2122
import theano.sparse as sparse
@@ -25,18 +26,18 @@
2526

2627

2728
class TestHelperFunc:
28-
def test_pandas_to_array(self):
29+
@pytest.mark.parametrize("input_dtype", ["int32", "int64", "float32", "float64"])
30+
def test_pandas_to_array(self, input_dtype):
2931
"""
3032
Ensure that pandas_to_array returns the dense array, masked array,
3133
graph variable, TensorVariable, or sparse matrix as appropriate.
3234
"""
3335
# Create the various inputs to the function
34-
sparse_input = sps.csr_matrix(np.eye(3))
35-
dense_input = np.arange(9).reshape((3, 3))
36+
sparse_input = sps.csr_matrix(np.eye(3)).astype(input_dtype)
37+
dense_input = np.arange(9).reshape((3, 3)).astype(input_dtype)
3638

3739
input_name = "input_variable"
3840
theano_graph_input = tt.as_tensor(dense_input, name=input_name)
39-
4041
pandas_input = pd.DataFrame(dense_input)
4142

4243
# All the even numbers are replaced with NaN
@@ -81,7 +82,18 @@ def test_pandas_to_array(self):
8182
theano_output = func(theano_graph_input)
8283
assert isinstance(theano_output, theano.graph.basic.Variable)
8384
npt.assert_allclose(theano_output.eval(), theano_graph_input.eval())
84-
assert theano_output.owner.inputs[0].name == input_name
85+
intX = pm.theanof._conversion_map[theano.config.floatX]
86+
if dense_input.dtype == intX or dense_input.dtype == theano.config.floatX:
87+
assert theano_output.owner is None # func should not have added new nodes
88+
assert theano_output.name == input_name
89+
else:
90+
assert theano_output.owner is not None # func should have casted
91+
assert theano_output.owner.inputs[0].name == input_name
92+
93+
if "float" in input_dtype:
94+
assert theano_output.dtype == theano.config.floatX
95+
else:
96+
assert theano_output.dtype == intX
8597

8698
# Check function behavior with generator data
8799
generator_output = func(square_generator)

0 commit comments

Comments
 (0)