Skip to content

Better coverage for float32 tests #6780

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ jobs:
floatx: [float32]
python-version: ["3.11"]
test-subset:
- tests/sampling/test_mcmc.py tests/ode/test_ode.py tests/ode/test_utils.py
- tests/sampling/test_mcmc.py tests/ode/test_ode.py tests/ode/test_utils.py tests/distributions/test_transform.py
fail-fast: false
runs-on: ${{ matrix.os }}
env:
Expand Down
6 changes: 5 additions & 1 deletion pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,8 +958,10 @@ class SimplexTransform(RVTransform):
name = "simplex"

def forward(self, value, *inputs):
value = pt.as_tensor(value)
log_value = pt.log(value)
shift = pt.sum(log_value, -1, keepdims=True) / value.shape[-1]
N = value.shape[-1].astype(value.dtype)
shift = pt.sum(log_value, -1, keepdims=True) / N
return log_value[..., :-1] - shift

def backward(self, value, *inputs):
Expand All @@ -968,7 +970,9 @@ def backward(self, value, *inputs):
return exp_value_max / pt.sum(exp_value_max, -1, keepdims=True)

def log_jac_det(self, value, *inputs):
value = pt.as_tensor(value)
N = value.shape[-1] + 1
N = N.astype(value.dtype)
sum_value = pt.sum(value, -1, keepdims=True)
value_sum_expanded = value + sum_value
value_sum_expanded = pt.concatenate([value_sum_expanded, pt.zeros(sum_value.shape)], -1)
Expand Down
97 changes: 61 additions & 36 deletions tests/distributions/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@

# some transforms (stick breaking) require addition of small slack in order to be numerically
# stable. The minimal addable slack for float32 is higher thus we need to be less strict
tol = 1e-7 if pytensor.config.floatX == "float64" else 1e-6
tol = 1e-7 if pytensor.config.floatX == "float64" else 1e-5
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just in case, float32 was not checked in the CI so the previous tolerance was not taken in account



def check_transform(transform, domain, constructor=pt.dscalar, test=0, rv_var=None):
def check_transform(transform, domain, constructor=pt.scalar, test=0, rv_var=None):
x = constructor("x")
x.tag.test_value = test
if rv_var is None:
Expand All @@ -57,18 +57,20 @@ def check_transform(transform, domain, constructor=pt.dscalar, test=0, rv_var=No
# FIXME: What's being tested here? That the transformed graph can compile?
forward_f = pytensor.function([x], transform.forward(x, *rv_inputs))
# test transform identity
identity_f = pytensor.function(
[x], transform.backward(transform.forward(x, *rv_inputs), *rv_inputs)
)
z = transform.backward(transform.forward(x, *rv_inputs))
assert z.type == x.type
identity_f = pytensor.function([x], z, *rv_inputs)
for val in domain.vals:
close_to(val, identity_f(val), tol)


def check_vector_transform(transform, domain, rv_var=None):
return check_transform(transform, domain, pt.dvector, test=np.array([0, 0]), rv_var=rv_var)
return check_transform(
transform, domain, pt.vector, test=floatX(np.array([0, 0])), rv_var=rv_var
)


def get_values(transform, domain=R, constructor=pt.dscalar, test=0, rv_var=None):
def get_values(transform, domain=R, constructor=pt.scalar, test=0, rv_var=None):
x = constructor("x")
x.tag.test_value = test
if rv_var is None:
Expand All @@ -81,7 +83,7 @@ def get_values(transform, domain=R, constructor=pt.dscalar, test=0, rv_var=None)
def check_jacobian_det(
transform,
domain,
constructor=pt.dscalar,
constructor=pt.scalar,
test=0,
make_comparable=None,
elemwise=False,
Expand Down Expand Up @@ -119,22 +121,26 @@ def test_simplex():
check_vector_transform(tr.simplex, Simplex(2))
check_vector_transform(tr.simplex, Simplex(4))

check_transform(tr.simplex, MultiSimplex(3, 2), constructor=pt.dmatrix, test=np.zeros((2, 2)))
check_transform(
tr.simplex, MultiSimplex(3, 2), constructor=pt.matrix, test=floatX(np.zeros((2, 2)))
)


def test_simplex_bounds():
vals = get_values(tr.simplex, Vector(R, 2), pt.dvector, np.array([0, 0]))
vals = get_values(tr.simplex, Vector(R, 2), pt.vector, floatX(np.array([0, 0])))

close_to(vals.sum(axis=1), 1, tol)
close_to_logical(vals > 0, True, tol)
close_to_logical(vals < 1, True, tol)

check_jacobian_det(tr.simplex, Vector(R, 2), pt.dvector, np.array([0, 0]), lambda x: x[:-1])
check_jacobian_det(
tr.simplex, Vector(R, 2), pt.vector, floatX(np.array([0, 0])), lambda x: x[:-1]
)


def test_simplex_accuracy():
val = np.array([-30])
x = pt.dvector("x")
val = floatX(np.array([-30]))
x = pt.vector("x")
x.tag.test_value = val
identity_f = pytensor.function([x], tr.simplex.forward(x, tr.simplex.backward(x, x)))
close_to(val, identity_f(val), tol)
Expand All @@ -148,28 +154,39 @@ def test_sum_to_1():
tr.SumTo1(2)

check_jacobian_det(
tr.univariate_sum_to_1, Vector(Unit, 2), pt.dvector, np.array([0, 0]), lambda x: x[:-1]
tr.univariate_sum_to_1,
Vector(Unit, 2),
pt.vector,
floatX(np.array([0, 0])),
lambda x: x[:-1],
)
check_jacobian_det(
tr.multivariate_sum_to_1, Vector(Unit, 2), pt.dvector, np.array([0, 0]), lambda x: x[:-1]
tr.multivariate_sum_to_1,
Vector(Unit, 2),
pt.vector,
floatX(np.array([0, 0])),
lambda x: x[:-1],
)


def test_log():
check_transform(tr.log, Rplusbig)

check_jacobian_det(tr.log, Rplusbig, elemwise=True)
check_jacobian_det(tr.log, Vector(Rplusbig, 2), pt.dvector, [0, 0], elemwise=True)
check_jacobian_det(tr.log, Vector(Rplusbig, 2), pt.vector, [0, 0], elemwise=True)

vals = get_values(tr.log)
close_to_logical(vals > 0, True, tol)


@pytest.mark.skipif(
pytensor.config.floatX == "float32", reason="Test is designed for 64bit precision"
)
def test_log_exp_m1():
check_transform(tr.log_exp_m1, Rplusbig)

check_jacobian_det(tr.log_exp_m1, Rplusbig, elemwise=True)
check_jacobian_det(tr.log_exp_m1, Vector(Rplusbig, 2), pt.dvector, [0, 0], elemwise=True)
check_jacobian_det(tr.log_exp_m1, Vector(Rplusbig, 2), pt.vector, [0, 0], elemwise=True)

vals = get_values(tr.log_exp_m1)
close_to_logical(vals > 0, True, tol)
Expand All @@ -179,7 +196,7 @@ def test_logodds():
check_transform(tr.logodds, Unit)

check_jacobian_det(tr.logodds, Unit, elemwise=True)
check_jacobian_det(tr.logodds, Vector(Unit, 2), pt.dvector, [0.5, 0.5], elemwise=True)
check_jacobian_det(tr.logodds, Vector(Unit, 2), pt.vector, [0.5, 0.5], elemwise=True)

vals = get_values(tr.logodds)
close_to_logical(vals > 0, True, tol)
Expand All @@ -191,7 +208,7 @@ def test_lowerbound():
check_transform(trans, Rplusbig)

check_jacobian_det(trans, Rplusbig, elemwise=True)
check_jacobian_det(trans, Vector(Rplusbig, 2), pt.dvector, [0, 0], elemwise=True)
check_jacobian_det(trans, Vector(Rplusbig, 2), pt.vector, [0, 0], elemwise=True)

vals = get_values(trans)
close_to_logical(vals > 0, True, tol)
Expand All @@ -202,7 +219,7 @@ def test_upperbound():
check_transform(trans, Rminusbig)

check_jacobian_det(trans, Rminusbig, elemwise=True)
check_jacobian_det(trans, Vector(Rminusbig, 2), pt.dvector, [-1, -1], elemwise=True)
check_jacobian_det(trans, Vector(Rminusbig, 2), pt.vector, [-1, -1], elemwise=True)

vals = get_values(trans)
close_to_logical(vals < 0, True, tol)
Expand Down Expand Up @@ -234,7 +251,7 @@ def test_interval_near_boundary():
pm.Uniform("x", initval=x0, lower=lb, upper=ub)

log_prob = model.point_logps()
np.testing.assert_allclose(list(log_prob.values()), np.array([-52.68]))
np.testing.assert_allclose(list(log_prob.values()), floatX(np.array([-52.68])))


def test_circular():
Expand All @@ -257,19 +274,19 @@ def test_ordered():
tr.Ordered(2)

check_jacobian_det(
tr.univariate_ordered, Vector(R, 2), pt.dvector, np.array([0, 0]), elemwise=False
tr.univariate_ordered, Vector(R, 2), pt.vector, floatX(np.array([0, 0])), elemwise=False
)
check_jacobian_det(
tr.multivariate_ordered, Vector(R, 2), pt.dvector, np.array([0, 0]), elemwise=False
tr.multivariate_ordered, Vector(R, 2), pt.vector, floatX(np.array([0, 0])), elemwise=False
)

vals = get_values(tr.univariate_ordered, Vector(R, 3), pt.dvector, np.zeros(3))
vals = get_values(tr.univariate_ordered, Vector(R, 3), pt.vector, floatX(np.zeros(3)))
close_to_logical(np.diff(vals) >= 0, True, tol)


def test_chain_values():
chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered])
vals = get_values(chain_tranf, Vector(R, 5), pt.dvector, np.zeros(5))
vals = get_values(chain_tranf, Vector(R, 5), pt.vector, floatX(np.zeros(5)))
close_to_logical(np.diff(vals) >= 0, True, tol)


Expand All @@ -281,7 +298,7 @@ def test_chain_vector_transform():
@pytest.mark.xfail(reason="Fails due to precision issue. Values just close to expected.")
def test_chain_jacob_det():
chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered])
check_jacobian_det(chain_tranf, Vector(R, 4), pt.dvector, np.zeros(4), elemwise=False)
check_jacobian_det(chain_tranf, Vector(R, 4), pt.vector, floatX(np.zeros(4)), elemwise=False)


class TestElementWiseLogp(SeededTest):
Expand Down Expand Up @@ -432,7 +449,7 @@ def transform_params(*inputs):
[
(0.0, 1.0, 2.0, 2),
(-10, 0, 200, (2, 3)),
(np.zeros(3), np.ones(3), np.ones(3), (4, 3)),
(floatX(np.zeros(3)), floatX(np.ones(3)), floatX(np.ones(3)), (4, 3)),
],
)
def test_triangular(self, lower, c, upper, size):
Expand All @@ -449,7 +466,8 @@ def transform_params(*inputs):
self.check_transform_elementwise_logp(model)

@pytest.mark.parametrize(
"mu,kappa,size", [(0.0, 1.0, 2), (-0.5, 5.5, (2, 3)), (np.zeros(3), np.ones(3), (4, 3))]
"mu,kappa,size",
[(0.0, 1.0, 2), (-0.5, 5.5, (2, 3)), (floatX(np.zeros(3)), floatX(np.ones(3)), (4, 3))],
)
def test_vonmises(self, mu, kappa, size):
model = self.build_model(
Expand Down Expand Up @@ -549,7 +567,9 @@ def transform_params(*inputs):
)
self.check_vectortransform_elementwise_logp(model)

@pytest.mark.parametrize("mu,kappa,size", [(0.0, 1.0, (2,)), (np.zeros(3), np.ones(3), (4, 3))])
@pytest.mark.parametrize(
"mu,kappa,size", [(0.0, 1.0, (2,)), (floatX(np.zeros(3)), floatX(np.ones(3)), (4, 3))]
)
def test_vonmises_ordered(self, mu, kappa, size):
initval = np.sort(np.abs(np.random.rand(*size)))
model = self.build_model(
Expand All @@ -566,7 +586,12 @@ def test_vonmises_ordered(self, mu, kappa, size):
[
(0.0, 1.0, (2,), tr.simplex),
(0.5, 5.5, (2, 3), tr.simplex),
(np.zeros(3), np.ones(3), (4, 3), tr.Chain([tr.univariate_sum_to_1, tr.logodds])),
(
floatX(np.zeros(3)),
floatX(np.ones(3)),
(4, 3),
tr.Chain([tr.univariate_sum_to_1, tr.logodds]),
),
],
)
def test_uniform_other(self, lower, upper, size, transform):
Expand All @@ -583,8 +608,8 @@ def test_uniform_other(self, lower, upper, size, transform):
@pytest.mark.parametrize(
"mu,cov,size,shape",
[
(np.zeros(2), np.diag(np.ones(2)), None, (2,)),
(np.zeros(3), np.diag(np.ones(3)), (4,), (4, 3)),
(floatX(np.zeros(2)), floatX(np.diag(np.ones(2))), None, (2,)),
(floatX(np.zeros(3)), floatX(np.diag(np.ones(3))), (4,), (4, 3)),
],
)
def test_mvnormal_ordered(self, mu, cov, size, shape):
Expand Down Expand Up @@ -643,7 +668,7 @@ def test_2d_univariate_ordered():
)

log_p = model.compile_logp(sum=False)(
{"x_1d_ordered__": np.zeros((4,)), "x_2d_ordered__": np.zeros((10, 4))}
{"x_1d_ordered__": floatX(np.zeros((4,))), "x_2d_ordered__": floatX(np.zeros((10, 4)))}
)
np.testing.assert_allclose(np.tile(log_p[0], (10, 1)), log_p[1])

Expand All @@ -667,7 +692,7 @@ def test_2d_multivariate_ordered():
)

log_p = model.compile_logp(sum=False)(
{"x_1d_ordered__": np.zeros((2,)), "x_2d_ordered__": np.zeros((2, 2))}
{"x_1d_ordered__": floatX(np.zeros((2,))), "x_2d_ordered__": floatX(np.zeros((2, 2)))}
)
np.testing.assert_allclose(log_p[0], log_p[1])

Expand All @@ -690,7 +715,7 @@ def test_2d_univariate_sum_to_1():
)

log_p = model.compile_logp(sum=False)(
{"x_1d_sumto1__": np.zeros(3), "x_2d_sumto1__": np.zeros((10, 3))}
{"x_1d_sumto1__": floatX(np.zeros(3)), "x_2d_sumto1__": floatX(np.zeros((10, 3)))}
)
np.testing.assert_allclose(np.tile(log_p[0], (10, 1)), log_p[1])

Expand All @@ -712,6 +737,6 @@ def test_2d_multivariate_sum_to_1():
)

log_p = model.compile_logp(sum=False)(
{"x_1d_sumto1__": np.zeros(1), "x_2d_sumto1__": np.zeros((2, 1))}
{"x_1d_sumto1__": floatX(np.zeros(1)), "x_2d_sumto1__": floatX(np.zeros((2, 1)))}
)
np.testing.assert_allclose(log_p[0], log_p[1])