|
16 | 16 | import numpy.ma as ma
|
17 | 17 | import numpy.testing as npt
|
18 | 18 | import pandas as pd
|
| 19 | +import pytest |
19 | 20 | import scipy.sparse as sps
|
20 | 21 | import theano
|
21 | 22 | import theano.sparse as sparse
|
|
25 | 26 |
|
26 | 27 |
|
27 | 28 | class TestHelperFunc:
|
28 |
| - def test_pandas_to_array(self): |
| 29 | + @pytest.mark.parametrize("input_dtype", ["int32", "int64", "float32", "float64"]) |
| 30 | + def test_pandas_to_array(self, input_dtype): |
29 | 31 | """
|
30 | 32 | Ensure that pandas_to_array returns the dense array, masked array,
|
31 | 33 | graph variable, TensorVariable, or sparse matrix as appropriate.
|
32 | 34 | """
|
33 | 35 | # Create the various inputs to the function
|
34 |
| - sparse_input = sps.csr_matrix(np.eye(3)) |
35 |
| - dense_input = np.arange(9).reshape((3, 3)) |
| 36 | + sparse_input = sps.csr_matrix(np.eye(3)).astype(input_dtype) |
| 37 | + dense_input = np.arange(9).reshape((3, 3)).astype(input_dtype) |
36 | 38 |
|
37 | 39 | input_name = "input_variable"
|
38 | 40 | theano_graph_input = tt.as_tensor(dense_input, name=input_name)
|
39 |
| - |
40 | 41 | pandas_input = pd.DataFrame(dense_input)
|
41 | 42 |
|
42 | 43 | # All the even numbers are replaced with NaN
|
@@ -81,7 +82,18 @@ def test_pandas_to_array(self):
|
81 | 82 | theano_output = func(theano_graph_input)
|
82 | 83 | assert isinstance(theano_output, theano.graph.basic.Variable)
|
83 | 84 | npt.assert_allclose(theano_output.eval(), theano_graph_input.eval())
|
84 |
| - assert theano_output.owner.inputs[0].name == input_name |
| 85 | + intX = pm.theanof._conversion_map[theano.config.floatX] |
| 86 | + if dense_input.dtype == intX or dense_input.dtype == theano.config.floatX: |
| 87 | + assert theano_output.owner is None # func should not have added new nodes |
| 88 | + assert theano_output.name == input_name |
| 89 | + else: |
| 90 | + assert theano_output.owner is not None # func should have casted |
| 91 | + assert theano_output.owner.inputs[0].name == input_name |
| 92 | + |
| 93 | + if "float" in input_dtype: |
| 94 | + assert theano_output.dtype == theano.config.floatX |
| 95 | + else: |
| 96 | + assert theano_output.dtype == intX |
85 | 97 |
|
86 | 98 | # Check function behavior with generator data
|
87 | 99 | generator_output = func(square_generator)
|
|
0 commit comments