/*
 * Lambert W function, i.e. the inverse of f(t) = t e^t.
 *
 * This function allows you to solve several related kinds of
 * equation, by composing it with assorted operations:
 *
 * If      t e^t = x      then t = W(x)
 * If    e^t / t = x      then t = -W(-1/x)
 * If    e^t + t = x      then t = log(W(exp(x)))
 * If    e^t - t = x      then t = log(-W(-exp(-x)))
 * If    t log t = x      then t = exp(W(x))
 * If  t / log t = x      then t = exp(-W(-1/x))
 * If  t + log t = x      then t = W(exp(x))
 * If  t - log t = x      then t = -W(-exp(-x))
 * If        t^t = x      then t = exp(W(log x))
 * If        x^t = t      then t = exp(-W(-log(x)))
 *
 * Since W has two branches (both starting at W(-1/e) = -1, but one
 * heads upwards and is valid for all x >= -1/e while the other heads
 * downwards and is only valid for -1/e <= x < 0), most of the above
 * types of equation can have different numbers of solutions depending
 * on whether the value passed to W is less than -1/e, greater than 0,
 * or in between the two.
 */

#include <stdio.h>
#include <stdlib.h>
#include <assert.h>

#include "spigot.h"
#include "error.h"
#include "funcs.h"

/*
 * Builds spigots for f(t) = t * e^t evaluated at a rational point
 * t = n/d. An instance is handed to spigot_monotone_invert (which
 * presumably evaluates f at trial points in order to invert it), and
 * construct() is also called directly in this file for sign tests.
 */
struct XexConstructor : MonotoneConstructor {
    MonotoneConstructor *clone() {
        XexConstructor *copy = new XexConstructor();
        return copy;
    }
    Spigot *construct(const bigint &n, const bigint &d) {
        Spigot *t = spigot_rational(n, d);
        Spigot *exp_t = spigot_exp(spigot_rational(n, d));
        return spigot_mul(t, exp_t);
    }
};

Spigot *spigot_lambertw_pos(Spigot *x)
{
    /*
     * Principal (upper) branch of Lambert W: the inverse of
     * f(t) = t e^t restricted to t >= -1. It starts at W(-1/e) = -1,
     * passes through the origin and then grows roughly like a
     * logarithm.
     *
     * The result comes from spigot_monotone_invert. That call needs a
     * pair of t values bracketing W(x), i.e. f(t_lo) < x < f(t_hi).
     *
     *  - t_lo = -1 always works. f(-1) = -1/e is the minimum of f, and
     *    x is required to be at least that.
     *
     *  - t_hi comes from (a rational approximation to) log(1+x).
     *    log(1+x) >= W(x) everywhere, with equality only at x = 0.
     *    Proof: f is increasing on the relevant range, so the claim is
     *    equivalent to (1+x)log(1+x) - x >= 0. That expression has
     *    derivative log(1+x), which is negative for x < 0 and positive
     *    for x > 0, so its only minimum is 0, at x = 0.
     *
     * Because x = 0 makes the upper bound tight, an exactly rational
     * zero is answered directly. The branch is so log-like that the
     * bounds never need to be tightly centred to avoid wasted work.
     */
    XexConstructor f;
    bigint num, den;

    if (x->is_rational(&num, &den) && num == 0)
        return spigot_integer(0);

    /*
     * Domain check: x >= -1/e.
     *
     * Doing this exactly up front would be an exactness hazard.
     * W(-1/e) is well defined, so an input sitting exactly at -1/e
     * must not hang here. Instead, a cheap approximate test is done
     * now:
     *  - inputs that are obviously too small fail immediately;
     *  - borderline ones get a deferred spigot_enforce check, which
     *    may fire part way through output.
     */
    {
        StaticGenerator margin(spigot_sub(f.construct(-1, 1), x->clone()));
        bigint rough = margin.get_approximate_approximant(64);
        if (rough > 1)                 // clearly below -1/e
            throw spigot_error("W of less than -1/e");
        if (rough > -2)                // can't rule it out yet
            x = spigot_enforce(x, ENFORCE_GE, f.construct(-1, 1),
                               spigot_error("W of less than -1/e"));
    }

    /*
     * Refine log(1+x) until one end of its bracketing interval is
     * confirmed to satisfy f(t) - x > 0. That end becomes
     * t_hi = hiNum / hiDen.
     */
    bigint hiNum, hiDen;
    {
        BracketingGenerator guess(spigot_log1p(x->clone()));
        for (;;) {
            bigint a, b;
            guess.get_bracket(&a, &b, &hiDen);

            int which = parallel_sign_test
                (spigot_sub(f.construct(a, hiDen), x->clone()),
                 spigot_sub(f.construct(b, hiDen), x->clone()));

            if (which == +1 || which == +2) {
                hiNum = (which == +1) ? a : b;
                break;
            }
            // Any other outcome: no end is yet known to have f above x,
            // so ask for a tighter bracket.
        }
    }

    // t_lo = -1, written as -hiDen/hiDen to share t_hi's denominator.
    return spigot_monotone_invert(new XexConstructor, true,
                                  -hiDen, hiNum, hiDen, x);
}

Spigot *spigot_lambertw_neg(Spigot *x)
{
    XexConstructor xex;

    /*
     * This is the negative branch of W, i.e. the one that starts from
     * W(-1/e) = -1, heads downwards, whizzes out to -infinity as x
     * approaches zero, and has no value at all for positive x.
     *
     * On this branch t = Wn(x) <= -1. There f(t) = t e^t is strictly
     * decreasing, since f'(t) = (1+t) e^t < 0 for t < -1. So a larger t
     * gives a smaller f(t).
     *
     * The names nhi/nlo below refer to the value of f at each bound,
     * not to the bound itself:
     *  - nhi/dhi is the MORE negative t, where f(t) > x;
     *  - nlo/dlo is the LESS negative t, where f(t) < x.
     */

    /*
     * Bounds checking for sensible error messages.
     */
    {
        StaticGenerator test(x->clone());
        if (test.get_sign() >= 0)
            throw spigot_error("Wn of a non-negative number");
    }

    /*
     * Lower-bound check, same as the version above. Obviously-too-small
     * x fails now. Borderline x gets a deferred spigot_enforce check,
     * so that x == -1/e exactly does not hang here.
     */
    {
        StaticGenerator test(spigot_sub(xex.construct(-1, 1), x->clone()));
        bigint approx = test.get_approximate_approximant(64);
        if (approx > 1)                // immediately obviously out of range
            throw spigot_error("Wn of less than -1/e");
        else if (approx > -2)          // _might_ be out of range
            x = spigot_enforce(x, ENFORCE_GE, xex.construct(-1, 1),
                               spigot_error("Wn of less than -1/e"));
    }

    /*
     * Find some bounds.
     *
     * We have -1/e <= x < 0 (probably - unless the deferred bounds
     * check above is going to kick in later). We seek t = Wn(x) <= -1,
     * or rather, we seek two values of t such that f(t_0) <= x and
     * f(t_1) > x.
     *
     * We have f(log(-x)) = log(-x) e^(log(-x)) = x (-log(-x)) <= x.
     * This holds because -log(-x) >= 1 given the range of x, and
     * multiplying the negative number x by at least 1 makes it no
     * larger. So t_0 = log(-x) will do, although we'll just use
     * t_0 = -1 if it looks dangerously close to that.
     *
     * We also have f(2 log(-x)) = 2 log(-x) e^(2 log(-x))
     * = 2x^2 log(-x) = x (2x log -x).
     * Given the range of x, 0 < 2x log -x <= 2/e, with equality only
     * at x = -1/e. So this multiplies the negative x by a positive
     * number _less_ than 1, which makes it larger. Hence t_1 =
     * 2 log(-x) will also do, and only a factor of two separates our
     * bounds.
     *
     * In the code, t_1 becomes nhi/dhi (f above x) and t_0 becomes
     * nlo/dlo (f below x).
     */

    // FIXME: we'll have to special-case one of these bounds to -1, I
    // think, to avoid a hazard when trying to find something between
    // it and -1/e.
    // NOTE(review): the lower-bound search below already substitutes -1
    // once log(-x) is seen to exceed -2. That looks like the special
    // case described here, so this FIXME may be stale; confirm before
    // removing it.

    // Bound where f(t) > x: refine 2 log(-x) until one end of its
    // bracket is confirmed to have f(t) - x positive.
    bigint nhi, dhi;
    {
        BracketingGenerator boundgen
            (spigot_rational_mul(spigot_log(spigot_neg(x->clone())), 2, 1));
        while (1) {
            bigint n1, n2;
            boundgen.get_bracket(&n1, &n2, &dhi);
            // dprint("upper bound: trying (%b, %b) / %b", &n1, &n2, &dhi);

            int s = parallel_sign_test
                (spigot_sub(xex.construct(n1, dhi), x->clone()),
                 spigot_sub(xex.construct(n2, dhi), x->clone()));

            // dprint("s = %d", s);

            if (s == +1) {
                nhi = n1;
                break;
            } else if (s == +2) {
                nhi = n2;
                break;
            }
        }
    }

    // Bound where f(t) < x: refine log(-x) until one end of its bracket
    // is confirmed to have f(t) - x negative, or fall back to -1.
    bigint nlo, dlo;
    {
        BracketingGenerator boundgen
            (spigot_log(spigot_neg(x->clone())));
        while (1) {
            bigint n1, n2;
            boundgen.get_bracket(&n1, &n2, &dlo);
            // If the bracket's first end shows log(-x) > -2, use t = -1
            // instead. It is always a valid bound here: f(-1) = -1/e
            // <= x, and -1 >= Wn(x). It also avoids sign-testing
            // f(t) - x near t = -1, where it can be exactly zero when
            // x = -1/e.
            if (n1 > -2*dlo) {
                // dprint("lower bound: got (%b, %b) / %b; just use -1",
                //        &n1, &n2, &dlo);
                nlo = -1;
                dlo = +1;
                break;
            }

            // dprint("lower bound: trying (%b, %b) / %b", &n1, &n2, &dlo);

            int s = parallel_sign_test
                (spigot_sub(xex.construct(n1, dlo), x->clone()),
                 spigot_sub(xex.construct(n2, dlo), x->clone()));

            // dprint("s = %d", s);

            if (s == -1) {
                nlo = n1;
                break;
            } else if (s == -2) {
                nlo = n2;
                break;
            }
        }
    }

    // dprint("bounds (%b/%b, %b/%b)", &nlo, &dlo, &nhi, &dhi);

    // Pass the two t bounds over the common denominator dhi*dlo: first
    // nhi/dhi (the more negative t), then nlo/dlo. The `false` flag
    // contrasts with `true` in the upper-branch function; presumably it
    // tells the inverter that f is decreasing on this interval
    // (confirm against spigot_monotone_invert).
    Spigot *ret = spigot_monotone_invert(new XexConstructor, false,
                                         nhi*dlo, nlo*dhi, dhi*dlo, x);
    return ret;
}
