diff --git a/src/flint/flint_base/flint_base.pyx b/src/flint/flint_base/flint_base.pyx index 3b5c7434..4e33ae6e 100644 --- a/src/flint/flint_base/flint_base.pyx +++ b/src/flint/flint_base/flint_base.pyx @@ -66,14 +66,38 @@ cdef class flint_poly(flint_elem): def roots(self): """ - Deprecated function. - - To recover roots of a polynomial, first convert to acb: - - acb_poly(input_poly).roots() + Computes all the roots in the base ring of the polynomial. + Returns a list of all pairs (*v*, *m*) where *v* is the + integer root and *m* is the multiplicity of the root. + + To compute complex roots of a polynomial, instead use + the `.complex_roots()` method, which is available on + certain polynomial rings. + + >>> from flint import fmpz_poly + >>> fmpz_poly([1, 2]).roots() + [] + >>> fmpz_poly([2, 1]).roots() + [(-2, 1)] + >>> fmpz_poly([12, 7, 1]).roots() + [(-3, 1), (-4, 1)] + >>> (fmpz_poly([-5,1]) * fmpz_poly([-5,1]) * fmpz_poly([-3,1])).roots() + [(3, 1), (5, 2)] """ - raise NotImplementedError('This method is no longer supported. To recover the complex roots first convert to acb_poly') - + factor_fn = getattr(self, "factor", None) + if not callable(factor_fn): + raise NotImplementedError("Polynomial has no factor method, roots cannot be determined") + + roots = [] + factors = self.factor() + for fac, m in factors[1]: + if fac.degree() == fac[1] == 1: + v = - fac[0] + roots.append((v, m)) + return roots + + def complex_roots(self): + raise AttributeError("Complex roots are not supported for this polynomial") cdef class flint_mpoly(flint_elem): diff --git a/src/flint/test/test.py b/src/flint/test/test.py index 34100e88..c19a96d0 100644 --- a/src/flint/test/test.py +++ b/src/flint/test/test.py @@ -438,11 +438,17 @@ def test_fmpz_poly(): assert Z([1,2,2]).sqrt() is None assert Z([1,0,2,0,3]).deflation() == (Z([1,2,3]), 2) assert Z([1,1]).deflation() == (Z([1,1]), 1) - [(r,m)] = Z([1,1]).roots() + [(r,m)] = Z([1,1]).complex_roots() assert m == 1 assert r.overlaps(-1) + assert Z([]).complex_roots() == [] + assert Z([1]).complex_roots() == [] + [(r,m)] = Z([1,1]).roots() + assert m == 1 + assert r == -1 assert Z([]).roots() == [] assert Z([1]).roots() == [] + assert Z([1, 2]).roots() == [] def test_fmpz_poly_factor(): Z = flint.fmpz_poly @@ -985,11 +991,13 @@ def set_bad(): assert Q.bernoulli_poly(3) == Q([0,1,-3,2],2) assert Q.euler_poly(3) == Q([1,0,-6,4],4) assert Q.legendre_p(3) == Q([0,-3,0,5],2) - assert Q([]).roots() == [] - assert Q([1]).roots() == [] - [(r,m)] = Q([1,1]).roots() + assert Q([]).complex_roots() == [] + assert Q([1]).complex_roots() == [] + [(r,m)] = Q([1,1]).complex_roots() assert m == 1 assert r.overlaps(-1) + assert str(Q([1,2]).roots()) == "[(-1/2, 1)]" + assert Q([2,1]).roots() == [(-2, 1)] def test_fmpq_mat(): Q = flint.fmpq_mat @@ -1411,6 +1419,9 @@ def set_bad2(): for alg in [None, 'berlekamp', 'cantor-zassenhaus']: assert p3.factor(alg) == f3 assert p3.factor(algorithm=alg) == f3 + assert P([1], 11).roots() == [] + assert P([1, 2, 3], 11).roots() == [(8, 1), (6, 1)] + assert P([1, 6, 1, 8], 11).roots() == [(5, 3)] def test_nmod_mat(): M = flint.nmod_mat diff --git a/src/flint/types/acb_poly.pyx b/src/flint/types/acb_poly.pyx index db9dc77b..7deca804 100644 --- a/src/flint/types/acb_poly.pyx +++ b/src/flint/types/acb_poly.pyx @@ -412,6 +412,8 @@ cdef class acb_poly(flint_poly): return pyroots + complex_roots = roots + def root_bound(self): """Returns an upper bound for the absolute value of the roots of self.""" diff --git a/src/flint/types/arb_poly.pyx b/src/flint/types/arb_poly.pyx index 85ad978d..8ea7bee9 100644 --- a/src/flint/types/arb_poly.pyx +++ b/src/flint/types/arb_poly.pyx @@ -113,6 +113,13 @@ cdef class arb_poly(flint_poly): libc.stdlib.free(xs) return u + def complex_roots(self, **kwargs): + """ + Compute the complex roots of the polynomial by converting + from arb_poly to acb_poly + """ + return acb_poly(self).roots(**kwargs) + def evaluate(self, xs, algorithm='fast'): """ Multipoint evaluation: evaluates *self* at the list of diff --git a/src/flint/types/fmpq_poly.pyx b/src/flint/types/fmpq_poly.pyx index b485bd0a..8c85b28d 100644 --- a/src/flint/types/fmpq_poly.pyx +++ b/src/flint/types/fmpq_poly.pyx @@ -384,16 +384,16 @@ cdef class fmpq_poly(flint_poly): fac[i] = (base, exp) return c / self.denom(), fac - def roots(self, **kwargs): + def complex_roots(self, **kwargs): """ Computes the complex roots of this polynomial. See :meth:`.fmpz_poly.roots`. >>> from flint import fmpq - >>> fmpq_poly([fmpq(2,3),1]).roots() + >>> fmpq_poly([fmpq(2,3),1]).complex_roots() [([-0.666666666666667 +/- 3.34e-16], 1)] """ - return self.numer().roots(**kwargs) + return self.numer().complex_roots(**kwargs) @staticmethod def bernoulli_poly(n): diff --git a/src/flint/types/fmpz_poly.pyx b/src/flint/types/fmpz_poly.pyx index ba532a62..5bb8f56c 100644 --- a/src/flint/types/fmpz_poly.pyx +++ b/src/flint/types/fmpz_poly.pyx @@ -343,19 +343,19 @@ cdef class fmpz_poly(flint_poly): fmpz_poly_factor_clear(fac) return c, res - def roots(self, bint verbose=False): + def complex_roots(self, bint verbose=False): """ Computes all the complex roots of this polynomial. Returns a list of pairs (*c*, *m*) where *c* is the root as an *acb* and *m* is the multiplicity of the root. - >>> fmpz_poly([]).roots() + >>> fmpz_poly([]).complex_roots() [] - >>> fmpz_poly([1]).roots() + >>> fmpz_poly([1]).complex_roots() [] - >>> fmpz_poly([2,0,1]).roots() + >>> fmpz_poly([2,0,1]).complex_roots() [([1.41421356237310 +/- 4.96e-15]j, 1), ([-1.41421356237310 +/- 4.96e-15]j, 1)] - >>> for c, m in (fmpz_poly([2,3,4]) * fmpz_poly([5,6,7,11])**3).roots(): + >>> for c, m in (fmpz_poly([2,3,4]) * fmpz_poly([5,6,7,11])**3).complex_roots(): ... print((c,m)) ... ([-0.375000000000000 +/- 1.0e-19] + [0.599478940414090 +/- 5.75e-17]j, 1) diff --git a/src/flint/types/nmod.pyx b/src/flint/types/nmod.pyx index bf83f4da..e5e73eb6 100644 --- a/src/flint/types/nmod.pyx +++ b/src/flint/types/nmod.pyx @@ -81,6 +81,12 @@ cdef class nmod(flint_scalar): return res else: return not res + elif typecheck(s, nmod) and typecheck(t, int): + res = s.val == (t % s.mod.n) + if op == 2: + return res + else: + return not res return NotImplemented def __nonzero__(self):