"""Tools for solving inequalities and systems of inequalities. """
from __future__ import print_function, division
from sympy.core import Symbol, Dummy, sympify
from sympy.core.compatibility import iterable
from sympy.sets import Interval
from sympy.core.relational import Relational, Eq, Ge, Lt
from sympy.sets.sets import FiniteSet, Union
from sympy.sets.fancysets import ImageSet
from sympy.core.singleton import S
from sympy.functions import Abs
from sympy.logic import And
from sympy.polys import Poly, PolynomialError, parallel_poly_from_expr
from sympy.polys.polyutils import _nsort
from sympy.utilities.misc import filldedent
def solve_poly_inequality(poly, rel):
    """Solve a polynomial inequality with rational coefficients.

    Parameters
    ==========

    poly : Poly
        Left-hand side ``p`` of the relation ``p rel 0``.  Must already be
        a ``Poly`` instance.
    rel : str
        Relation operator: ``'=='``, ``'!='``, ``'>'``, ``'<'``, ``'>='``
        or ``'<='``.

    Returns
    =======

    list
        A list of sets (intervals or single points) whose union is the
        solution.  ``[S.Reals]`` or ``[S.EmptySet]`` is returned when
        ``poly`` is a number and the relation evaluates to true or false.

    Raises
    ======

    ValueError
        If ``poly`` is not a ``Poly`` or ``rel`` is not one of the
        operators listed above.
    NotImplementedError
        If ``poly`` is a number and the relation does not evaluate to
        ``S.true`` or ``S.false``.

    Examples
    ========

    >>> from sympy import Poly
    >>> from sympy.abc import x
    >>> from sympy.solvers.inequalities import solve_poly_inequality

    >>> solve_poly_inequality(Poly(x, x, domain='ZZ'), '==')
    [{0}]

    >>> solve_poly_inequality(Poly(x**2 - 1, x, domain='ZZ'), '!=')
    [Interval.open(-oo, -1), Interval.open(-1, 1), Interval.open(1, oo)]

    >>> solve_poly_inequality(Poly(x**2 - 1, x, domain='ZZ'), '==')
    [{-1}, {1}]

    See Also
    ========
    solve_poly_inequalities
    """
    if not isinstance(poly, Poly):
        raise ValueError(
            'For efficiency reasons, `poly` should be a Poly instance')

    if poly.is_number:
        # A constant polynomial either always or never satisfies the relation.
        t = Relational(poly.as_expr(), 0, rel)
        if t is S.true:
            return [S.Reals]
        elif t is S.false:
            return [S.EmptySet]
        else:
            raise NotImplementedError(
                "could not determine truth value of %s" % t)

    # ``reals`` is a list of (root, multiplicity) pairs.
    reals, intervals = poly.real_roots(multiple=False), []

    if rel == '==':
        for root, _ in reals:
            intervals.append(Interval(root, root))
    elif rel == '!=':
        # Every open gap between consecutive roots, plus the two unbounded
        # ends, is part of the solution.
        left = S.NegativeInfinity

        for right, _ in reals + [(S.Infinity, 1)]:
            intervals.append(Interval(left, right, True, True))
            left = right
    else:
        # ``sign`` tracks the sign of the polynomial on the piece currently
        # being examined; it starts to the right of the last root, where it
        # is taken from the leading coefficient.
        sign = +1 if poly.LC() > 0 else -1

        # ``eq_sign`` is the sign the relation asks for and ``equal`` tells
        # whether the roots themselves satisfy it.
        eq_sign, equal = None, False

        if rel == '>':
            eq_sign = +1
        elif rel == '<':
            eq_sign = -1
        elif rel == '>=':
            eq_sign, equal = +1, True
        elif rel == '<=':
            eq_sign, equal = -1, True
        else:
            raise ValueError("'%s' is not a valid relation" % rel)

        right, right_open = S.Infinity, True

        # Walk the roots from last to first, prepending pieces so that the
        # result keeps the order of ``reals``.
        for left, multiplicity in reversed(reals):
            if multiplicity % 2:
                # Odd multiplicity: the sign flips at this root.
                if sign == eq_sign:
                    intervals.insert(
                        0, Interval(left, right, not equal, right_open))

                sign, right, right_open = -sign, left, not equal
            else:
                # Even multiplicity: the sign is the same on both sides.
                if sign == eq_sign and not equal:
                    # Strict relation: the root splits the current piece.
                    intervals.insert(
                        0, Interval(left, right, True, right_open))
                    right, right_open = left, True
                elif sign != eq_sign and equal:
                    # Non-strict relation on a piece of the wrong sign:
                    # only the root itself is a solution.
                    intervals.insert(0, Interval(left, left))

        if sign == eq_sign:
            # The unbounded piece to the left of the first root.
            intervals.insert(
                0, Interval(S.NegativeInfinity, right, True, right_open))

    return intervals
def solve_poly_inequalities(polys):
    """Solve polynomial inequalities with rational coefficients.

    ``polys`` is an iterable of argument tuples for
    ``solve_poly_inequality`` (a ``Poly`` followed by a relation string).
    The per-inequality results are combined into a single ``Union``.

    Examples
    ========

    >>> from sympy.solvers.inequalities import solve_poly_inequalities
    >>> from sympy.polys import Poly
    >>> from sympy.abc import x
    >>> solve_poly_inequalities(((
    ... Poly(x**2 - 3), ">"), (
    ... Poly(-x**2 + 1), ">")))
    Union(Interval.open(-oo, -sqrt(3)), Interval.open(-1, 1), Interval.open(sqrt(3), oo))
    """
    from sympy import Union

    # Each entry contributes a list of intervals; they become the arguments
    # of the resulting Union in the order the inequalities were given.
    solved = [solve_poly_inequality(*args) for args in polys]
    return Union(*solved)
def solve_rational_inequalities(eqs):
    """Solve a system of rational inequalities with rational coefficients.

    Parameters
    ==========

    eqs : iterable of iterables
        Each inner iterable is a group of ``((numer, denom), rel)`` items,
        where ``numer`` and ``denom`` are ``Poly`` instances and ``rel`` is
        a relation string accepted by ``solve_poly_inequality``.  All items
        of a group must hold simultaneously; the solutions of the groups
        are united.  Empty groups are skipped.

    Returns
    =======

    Set
        The union of the solutions of all groups (``S.EmptySet`` when there
        is nothing to unite).

    Examples
    ========

    >>> from sympy.abc import x
    >>> from sympy import Poly
    >>> from sympy.solvers.inequalities import solve_rational_inequalities

    >>> solve_rational_inequalities([[
    ... ((Poly(-x + 1), Poly(1, x)), '>='),
    ... ((Poly(-x + 1), Poly(1, x)), '<=')]])
    {1}

    >>> solve_rational_inequalities([[
    ... ((Poly(x), Poly(1, x)), '!='),
    ... ((Poly(-x + 1), Poly(1, x)), '>=')]])
    Union(Interval.open(-oo, 0), Interval.Lopen(0, 1))

    See Also
    ========
    solve_poly_inequality
    """
    result = S.EmptySet

    for _eqs in eqs:
        if not _eqs:
            continue

        # Pieces of the real line that satisfy every relation of this group
        # processed so far.
        global_intervals = [Interval(S.NegativeInfinity, S.Infinity)]

        for (numer, denom), rel in _eqs:
            # The relation is solved for numer*denom, and the zeros of the
            # denominator are removed from the result afterwards.
            numer_intervals = solve_poly_inequality(numer*denom, rel)
            denom_intervals = solve_poly_inequality(denom, '==')

            # Intersect the new solution with what is left of the group.
            candidates = (
                numer_interval.intersect(global_interval)
                for numer_interval in numer_intervals
                for global_interval in global_intervals)
            global_intervals = [
                interval for interval in candidates
                if interval is not S.EmptySet]

            # Remove the points where the denominator vanishes.
            intervals = []
            for global_interval in global_intervals:
                for denom_interval in denom_intervals:
                    global_interval -= denom_interval

                if global_interval is not S.EmptySet:
                    intervals.append(global_interval)
            global_intervals = intervals

            if not global_intervals:
                # Nothing satisfies this group any more.
                break

        for interval in global_intervals:
            result = result.union(interval)

    return result
def reduce_rational_inequalities(exprs, gen, relational=True):
    """Reduce a system of rational inequalities with rational coefficients.

    Parameters
    ==========

    exprs : list of lists
        Groups of inequalities.  Every item of a group is either a tuple
        ``(expr, rel)``, a ``Relational`` (rewritten as ``lhs - rhs`` with
        its own operator) or a plain expression (treated as ``expr == 0``).
        ``S.true`` and ``S.false`` are accepted as expressions.
    gen : Symbol
        The variable to solve for.
    relational : bool, optional
        If true (the default) the result is converted with
        ``as_relational(gen)``; otherwise the solution set is returned.

    Raises
    ======

    PolynomialError
        If an item cannot be turned into a pair of polynomials in ``gen``.

    Examples
    ========

    >>> from sympy import Poly, Symbol
    >>> from sympy.solvers.inequalities import reduce_rational_inequalities

    >>> x = Symbol('x', real=True)

    >>> reduce_rational_inequalities([[x**2 <= 0]], x)
    Eq(x, 0)

    >>> reduce_rational_inequalities([[x + 2 > 0]], x)
    (-2 < x) & (x < oo)
    >>> reduce_rational_inequalities([[(x + 2, ">")]], x)
    (-2 < x) & (x < oo)
    >>> reduce_rational_inequalities([[x + 2]], x)
    Eq(x, -2)
    """
    exact = True
    eqs = []
    solution = S.Reals if exprs else S.EmptySet

    for _exprs in exprs:
        _eqs = []

        for expr in _exprs:
            # Normalise the item to an expression and a relation against 0.
            if isinstance(expr, tuple):
                expr, rel = expr
            else:
                if expr.is_Relational:
                    expr, rel = expr.lhs - expr.rhs, expr.rel_op
                else:
                    expr, rel = expr, '=='

            if expr is S.true:
                # Encoded as the always-true relation 0/1 == 0.
                numer, denom, rel = S.Zero, S.One, '=='
            elif expr is S.false:
                # Encoded as the never-true relation 1/1 == 0.
                numer, denom, rel = S.One, S.One, '=='
            else:
                numer, denom = expr.together().as_numer_denom()

            try:
                (numer, denom), opt = parallel_poly_from_expr(
                    (numer, denom), gen)
            except PolynomialError:
                raise PolynomialError(filldedent('''
                    only polynomials and
                    rational functions are supported in this context'''))

            if not opt.domain.is_Exact:
                # Solve with exact coefficients and evaluate the solution
                # numerically at the end.
                numer, denom, exact = numer.to_exact(), denom.to_exact(), False

            domain = opt.domain.get_exact()

            if not (domain.is_ZZ or domain.is_QQ):
                # Coefficients outside ZZ/QQ are handed to the general
                # univariate solver and intersected into the solution.
                expr = numer/denom
                expr = Relational(expr, 0, rel)
                solution &= solve_univariate_inequality(expr, gen, relational=False)
            else:
                _eqs.append(((numer, denom), rel))

        if _eqs:
            eqs.append(_eqs)

    if eqs:
        solution &= solve_rational_inequalities(eqs)

    if not exact:
        solution = solution.evalf()

    if relational:
        solution = solution.as_relational(gen)

    return solution
def reduce_abs_inequality(expr, rel, gen):
    """Reduce an inequality with nested absolute values.

    Parameters
    ==========

    expr : Expr
        Left-hand side of the relation ``expr rel 0``.  It is traversed
        through sums, products, powers and ``Abs``; anything else is taken
        as is.
    rel : str
        Relation operator.
    gen : Symbol
        The variable to solve for.

    Raises
    ======

    TypeError
        If ``gen.is_real`` is ``False``.
    ValueError
        If ``expr`` contains a power whose exponent is not an ``Integer``.

    Examples
    ========

    >>> from sympy import Abs, Symbol
    >>> from sympy.solvers.inequalities import reduce_abs_inequality
    >>> x = Symbol('x', real=True)

    >>> reduce_abs_inequality(Abs(x - 5) - 3, '<', x)
    (2 < x) & (x < 8)

    >>> reduce_abs_inequality(Abs(x + 2)*3 - 13, '<', x)
    (-19/3 < x) & (x < 7/3)

    See Also
    ========
    reduce_abs_inequalities
    """
    if gen.is_real is False:
        raise TypeError(filldedent('''
            can't solve inequalities with absolute
            values containing non-real variables'''))

    def _bottom_up_scan(node):
        # Return a list of ``(expression, conditions)`` cases for ``node``
        # in which every ``Abs`` has been replaced by its argument or the
        # negated argument, with the matching sign condition recorded.
        if node.is_Add or node.is_Mul:
            op = node.func
            cases = []

            for arg in node.args:
                arg_cases = _bottom_up_scan(arg)

                if not cases:
                    cases = arg_cases
                else:
                    # Combine every case collected so far with every case
                    # of the new argument.
                    cases = [
                        (op(lhs, rhs), lhs_conds + rhs_conds)
                        for lhs, lhs_conds in cases
                        for rhs, rhs_conds in arg_cases]

            return cases

        if node.is_Pow:
            n = node.exp

            if not n.is_Integer:
                raise ValueError("Only Integer Powers are allowed on Abs.")

            return [(base**n, conds)
                    for base, conds in _bottom_up_scan(node.base)]

        if isinstance(node, Abs):
            cases = []

            for inner, conds in _bottom_up_scan(node.args[0]):
                # Abs(inner) is inner when inner >= 0 and -inner otherwise.
                cases.append((inner, conds + [Ge(inner, 0)]))
                cases.append((-inner, conds + [Lt(inner, 0)]))

            return cases

        return [(node, [])]

    # '<' and '<=' are rewritten as '>' and '>=' on the negated expression.
    mapping = {'<': '>', '<=': '>='}

    inequalities = []

    for case, conds in _bottom_up_scan(expr):
        if rel not in mapping:
            relation = Relational(case, 0, rel)
        else:
            relation = Relational(-case, 0, mapping[rel])

        inequalities.append([relation] + conds)

    return reduce_rational_inequalities(inequalities, gen)
def reduce_abs_inequalities(exprs, gen):
    """Reduce a system of inequalities with nested absolute values.

    Parameters
    ==========

    exprs : iterable of (expr, rel) pairs
        Each pair is passed to ``reduce_abs_inequality`` as
        ``reduce_abs_inequality(expr, rel, gen)``.
    gen : Symbol
        The variable to solve for.

    Returns
    =======

    The conjunction (``And``) of the individual results.

    Examples
    ========

    >>> from sympy import Abs, Symbol
    >>> from sympy.abc import x
    >>> from sympy.solvers.inequalities import reduce_abs_inequalities
    >>> x = Symbol('x', real=True)

    >>> reduce_abs_inequalities([(Abs(3*x - 5) - 7, '<'),
    ... (Abs(x + 25) - 13, '>')], x)
    (-2/3 < x) & (x < 4) & (((-oo < x) & (x < -38)) | ((-12 < x) & (x < oo)))

    >>> reduce_abs_inequalities([(Abs(x - 4) + Abs(3*x - 5) - 7, '<')], x)
    (1/2 < x) & (x < 4)

    See Also
    ========
    reduce_abs_inequality
    """
    return And(*[reduce_abs_inequality(expr, rel, gen)
                 for expr, rel in exprs])
def solve_univariate_inequality(expr, gen, relational=True, domain=S.Reals, continuous=False):
    """Solves a real univariate inequality.

    Parameters
    ==========

    expr : Relational
        The target inequality
    gen : Symbol
        The variable for which the inequality is solved
    relational : bool
        A Relational type output is expected or not
    domain : Set
        The domain over which the equation is solved
    continuous: bool
        True if expr is known to be continuous over the given domain
        (and so continuous_domain() doesn't need to be called on it)

    Raises
    ======

    NotImplementedError
        The solution of the inequality cannot be determined due to limitation
        in `solvify`, or the critical points cannot be sorted.

    Notes
    =====

    Currently, we cannot solve all the inequalities due to limitations in
    `solvify`. Also, the solution returned for trigonometric inequalities
    are restricted in its periodic interval.

    See Also
    ========

    solvify: solver returning solveset solutions with solve's output API

    Examples
    ========

    >>> from sympy.solvers.inequalities import solve_univariate_inequality
    >>> from sympy import Symbol, sin, Interval, S
    >>> x = Symbol('x')

    >>> solve_univariate_inequality(x**2 >= 4, x)
    ((2 <= x) & (x < oo)) | ((x <= -2) & (-oo < x))

    >>> solve_univariate_inequality(x**2 >= 4, x, relational=False)
    Union(Interval(-oo, -2), Interval(2, oo))

    >>> domain = Interval(0, S.Infinity)
    >>> solve_univariate_inequality(x**2 >= 4, x, False, domain)
    Interval(2, oo)

    >>> solve_univariate_inequality(sin(x) > 0, x, relational=False)
    Interval.open(0, pi)

    """
    from sympy.calculus.util import (continuous_domain, periodicity,
                                     function_range)
    from sympy.solvers.solvers import denoms
    from sympy.solvers.solveset import solvify

    # This keeps the function independent of the assumptions about `gen`.
    # `solveset` makes sure this function is called only when the domain is
    # real.
    d = Dummy(real=True)
    expr = expr.subs(gen, d)
    _gen = gen
    gen = d
    rv = None

    if expr is S.true:
        rv = domain

    elif expr is S.false:
        rv = S.EmptySet

    else:
        # Work with ``e rel 0`` from here on.
        e = expr.lhs - expr.rhs
        period = periodicity(e, gen)
        if period is not None:
            frange = function_range(e, gen, domain)

            # If the relation already holds at the unfavourable end of the
            # range it holds everywhere; if it fails at the favourable end
            # it holds nowhere.
            rel = expr.rel_op
            if rel == '<' or rel == '<=':
                if expr.func(frange.sup, 0):
                    rv = domain
                elif not expr.func(frange.inf, 0):
                    rv = S.EmptySet

            elif rel == '>' or rel == '>=':
                if expr.func(frange.inf, 0):
                    rv = domain
                elif not expr.func(frange.sup, 0):
                    rv = S.EmptySet

            # Over a domain of infinite extent only the single period
            # [0, period) is examined.
            inf, sup = domain.inf, domain.sup
            if sup - inf is S.Infinity:
                domain = Interval(0, period, False, True)

        if rv is None:
            solns = solvify(e, gen, domain)
            if solns is None:
                raise NotImplementedError(filldedent('''
                    The inequality cannot be
                    solved using solve_univariate_inequality.'''))

            # Zeros of the denominators can never be part of the solution.
            singularities = []
            for denom_expr in denoms(expr, gen):
                singularities.extend(solvify(denom_expr, gen, domain))
            if not continuous:
                domain = continuous_domain(e, gen, domain)

            # True exactly when the relation admits equality, i.e. when
            # the zeros of ``e`` themselves are solutions.
            include_x = expr.func(0, 0)

            def valid(x):
                # Decide whether the relation holds at the point ``x``;
                # anything that cannot be decided counts as false.
                v = e.subs(gen, x)
                try:
                    r = expr.func(v, 0)
                except TypeError:
                    r = S.false
                if r in (S.true, S.false):
                    return r
                if v.is_real is False:
                    return S.false
                else:
                    # Fall back to a low-precision numerical comparison.
                    v = v.n(2)
                    if v.is_comparable:
                        return expr.func(v, 0)
                    return S.false

            try:
                discontinuities = set(domain.boundary -
                                      FiniteSet(domain.inf, domain.sup))
                # remove points that are not between inf and sup of domain
                critical_points = FiniteSet(*(solns + singularities + list(
                    discontinuities))).intersection(
                        Interval(domain.inf, domain.sup,
                                 domain.inf not in domain,
                                 domain.sup not in domain))
                reals = _nsort(critical_points, separated=True)[0]
            except NotImplementedError:
                raise NotImplementedError('sorting of these roots is not supported')

            sol_sets = [S.EmptySet]

            start = domain.inf
            if valid(start) and start.is_finite:
                sol_sets.append(FiniteSet(start))

            # Between consecutive critical points one sample point is
            # tested to decide whether the whole open piece is a solution.
            for x in reals:
                end = x

                if valid(_pt(start, end)):
                    sol_sets.append(Interval(start, end, True, True))

                if x in singularities:
                    singularities.remove(x)
                else:
                    if x in discontinuities:
                        discontinuities.remove(x)
                        _valid = valid(x)
                    else:  # it's a solution
                        _valid = include_x
                    if _valid:
                        sol_sets.append(FiniteSet(x))

                start = end

            end = domain.sup
            if valid(end) and end.is_finite:
                sol_sets.append(FiniteSet(end))

            if valid(_pt(start, end)):
                sol_sets.append(Interval.open(start, end))

            # Put the caller's symbol back in place of the Dummy.
            rv = Union(*sol_sets).subs(gen, _gen)

    return rv if not relational else rv.as_relational(_gen)
def _pt(start, end):
    """Return a point between start and end"""
    if start.is_infinite:
        # Both ends unbounded: use zero; otherwise step back from ``end``.
        if end.is_infinite:
            return S.Zero
        return end - 1
    if end.is_infinite:
        # Only the upper end is unbounded: step forward from ``start``.
        return start + 1
    # Both ends are finite: take the midpoint.
    return (start + end)/2
def _solve_inequality(ie, s):
    """Solve the relational ``ie`` for the symbol ``s``.

    A hacky stand-in for solve, which only handles univariate
    inequalities.  Relations that are linear in ``s`` are rearranged
    directly; everything else is passed to
    ``reduce_rational_inequalities`` and, if that raises
    ``PolynomialError``, to ``solve_univariate_inequality``.
    """
    difference = ie.lhs - ie.rhs

    try:
        linear = Poly(difference, s)
        if linear.degree() != 1:
            raise NotImplementedError
    except (PolynomialError, NotImplementedError):
        # Not a degree-one polynomial in ``s``.
        try:
            return reduce_rational_inequalities([[ie]], s)
        except PolynomialError:
            return solve_univariate_inequality(ie, s)

    # difference == slope*s + intercept
    slope, intercept = linear.all_coeffs()

    if slope.is_positive or ie.rel_op in ('!=', '=='):
        return ie.func(s, -intercept/slope)
    if slope.is_negative:
        # Dividing by a negative coefficient reverses the relation.
        return ie.reversed.func(s, -intercept/slope)

    # The sign of the coefficient is unknown.
    raise NotImplementedError
def _reduce_inequalities(inequalities, symbols):
    """Helper for reduce_inequalities.

    Sorts the relationals (each with a right-hand side of 0) into
    polynomial ones, ones whose only non-polynomial parts are ``Abs``, and
    the rest, reduces each group and returns the conjunction of the
    results.
    """
    by_gen_poly, by_gen_abs = {}, {}
    solved_directly = []

    for relation in inequalities:
        lhs, op = relation.lhs, relation.rel_op  # rhs is 0

        # check for gens using atoms which is more strict than free_symbols
        # to guard against EX domain which won't be handled by
        # reduce_rational_inequalities
        candidates = lhs.atoms(Symbol)

        if len(candidates) == 1:
            gen = candidates.pop()
        else:
            of_interest = lhs.free_symbols & symbols
            if len(of_interest) != 1:
                raise NotImplementedError(filldedent('''
                    inequality has more than one
                    symbol of interest'''))
            gen = of_interest.pop()
            solved_directly.append(
                _solve_inequality(Relational(lhs, 0, op), gen))
            continue

        if lhs.is_polynomial(gen):
            by_gen_poly.setdefault(gen, []).append((lhs, op))
            continue

        # Sub-expressions in ``gen`` that are functions or non-integer
        # powers.
        components = lhs.find(
            lambda u: u.has(gen) and (
                u.is_Function or u.is_Pow and not u.exp.is_Integer))

        if components and all(isinstance(i, Abs) for i in components):
            by_gen_abs.setdefault(gen, []).append((lhs, op))
        else:
            solved_directly.append(
                _solve_inequality(Relational(lhs, 0, op), gen))

    poly_reduced = [reduce_rational_inequalities([group], gen)
                    for gen, group in by_gen_poly.items()]
    abs_reduced = [reduce_abs_inequalities(group, gen)
                   for gen, group in by_gen_abs.items()]

    return And(*(poly_reduced + abs_reduced + solved_directly))
def reduce_inequalities(inequalities, symbols=[]):
    """Reduce a system of inequalities with rational coefficients.

    Parameters
    ==========

    inequalities : a single item or an iterable of items
        Each item is sympified.  Relationals are rewritten as
        ``lhs - rhs rel 0``; other items that are not ``True`` or ``False``
        are treated as ``Eq(item, 0)``.
    symbols : Symbol or iterable of Symbols, optional
        The symbols of interest.  Only those that occur in the
        inequalities are kept; if none are given, all free symbols of the
        inequalities are used.

    Returns
    =======

    ``S.false`` as soon as one inequality reduces to false, otherwise the
    reduced system expressed in the original symbols.

    Raises
    ======

    TypeError
        If a symbol of interest has ``is_real`` equal to ``False``.
    NotImplementedError
        If an inequality has a numeric left-hand side but did not evaluate
        to true or false.

    Examples
    ========

    >>> from sympy import sympify as S, Symbol
    >>> from sympy.abc import x, y
    >>> from sympy.solvers.inequalities import reduce_inequalities

    >>> reduce_inequalities(0 <= x + 3, [])
    (-3 <= x) & (x < oo)

    >>> reduce_inequalities(0 <= x + y*2 - 1, [x])
    x >= -2*y + 1
    """
    # NOTE: the list default of ``symbols`` is kept for interface stability;
    # it is never mutated because ``symbols`` is only rebound below.
    if not iterable(inequalities):
        inequalities = [inequalities]
    inequalities = [sympify(i) for i in inequalities]

    gens = set().union(*[i.free_symbols for i in inequalities])

    if not iterable(symbols):
        symbols = [symbols]
    symbols = (set(symbols) or gens) & gens
    if any(i.is_real is False for i in symbols):
        raise TypeError(filldedent('''
            inequalities cannot contain symbols that are not real.'''))

    # make vanilla symbol real
    recast = {i: Dummy(i.name, real=True)
              for i in gens if i.is_real is None}
    inequalities = [i.xreplace(recast) for i in inequalities]
    symbols = {i.xreplace(recast) for i in symbols}

    # prefilter: normalise every item to ``lhs rel 0`` and drop the ones
    # that are already decided
    keep = []
    for i in inequalities:
        if isinstance(i, Relational):
            i = i.func(i.lhs.as_expr() - i.rhs.as_expr(), 0)
        elif i not in (True, False):
            i = Eq(i, 0)
        if i == True:
            continue
        elif i == False:
            return S.false
        if i.lhs.is_number:
            raise NotImplementedError(
                "could not determine truth value of %s" % i)
        keep.append(i)
    inequalities = keep
    del keep

    # solve system
    rv = _reduce_inequalities(inequalities, symbols)

    # restore original symbols and return
    return rv.xreplace({v: k for k, v in recast.items()})