from sympy import *
from matplotlib import pyplot as plt
from scipy import integrate as sp_int
from scipy import optimize
import numpy as np


def plot_expr(ax, f, x_sym, lim, N):
    """
    Plot a sympy expression f(x_sym) against the variable x_sym on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    f : sympy expression in ``x_sym``.
    x_sym : sympy Symbol used as the independent variable.
    lim : ``[low, high]`` range of x_sym to plot.
    N : number of evenly spaced sample points, including both end points.
        ``N == 1`` samples only ``low``; ``N <= 0`` plots nothing.
    """
    F = lambdify(x_sym, f)
    # np.linspace copes with N == 1, where a step of (high - low) / (N - 1)
    # would divide by zero.
    x = np.linspace(lim[0], lim[1], max(N, 0))
    # Evaluate point by point so a constant f (for which F returns a scalar)
    # still yields one y value per x.
    y = [F(xi) for xi in x]
    ax.plot(x, y)


def derive_wetted_perim(w, y, A_sym, debug=False):
    """
    Derive an analytical expression for the wetted perimeter of the river.

    The channel is symmetric about its centre line, with banks at x = +-w/2,
    where ``w`` is the width at height ``y`` above the riverbed. For water depth
    h the wetted perimeter is the base width w(0) plus the arc length of both
    banks from y = 0 to y = h. The filled area A(h) = integral of w dy from 0 to
    h is inverted for h, and that h(A) is substituted into the perimeter.

    Parameters
    ----------
    w : sympy expression (or number) for the channel width in terms of ``y``.
    y : sympy Symbol for height above the riverbed.
    A_sym : sympy Symbol for the area filled with water.
    debug : if True, print the intermediate expressions l(h), A(h), h(A), l(A).

    Returns
    -------
    sympy expression for the wetted perimeter in terms of ``A_sym``. If
    A(h) = A_sym has several solutions for h, the first one returned by
    ``solve`` is used.

    Raises
    ------
    ValueError
        If A(h) = A_sym cannot be solved for h.
    """
    # Allow a plain number (e.g. a constant width) as well as a sympy expression.
    w = sympify(w)
    h = Symbol("h", positive=True, real=True)

    # Arc length of a bank: integral of sqrt((dx/dy)^2 + 1) dy with x = +-w/2.
    # The factor 2 counts both banks; w(0) adds the flat bottom (if any).
    dx_dy = diff(w/2, y)
    dl_dy = 2 * sqrt(dx_dy**2 + 1)
    l = integrate(dl_dy, (y, 0, h)) + w.subs(y, 0)
    if debug:
        print("l(h) =", l)

    # Channel area: integrate the width over the water depth.
    A = simplify(integrate(w, (y, 0, h)))
    if debug:
        print("A(h) =", A)

    # Invert A(h) for h; there may be several (or no) solutions.
    h_of_A_arr = solve(A_sym - A, h)
    if not h_of_A_arr:
        raise ValueError(f"Cannot solve A(h) = {A_sym} for h, with A(h) = {A}")

    h_of_A = h_of_A_arr[0]
    if debug:
        print("================")
        print("h(A) =")
        pprint(h_of_A)

    # Substitute h(A) into l(h) to get l(A).
    l_of_A = simplify(l.subs(h, h_of_A))
    if debug:
        print("l(A) =")
        pprint(l_of_A)

    return l_of_A


def godunov_step(Q, delta, c, c_inflow):
    """
    Return dc/dt in every cell, for use as the right-hand side of an ODE solver.

    Parameters
    ----------
    Q : function of c returning the flux (c is the area A in this model).
    delta : spatial step (cell length) along the river.
    c : sequence of the conserved quantity in each cell, ordered left to right.
    c_inflow : scalar value of c entering at the left boundary.

    Returns
    -------
    list with ``dc_i/dt = -(Q(c_i) - Q(c_{i-1})) / delta``, where ``c_inflow``
    plays the role of the cell left of the first one. Empty ``c`` gives ``[]``.

    Each interface takes the flux of the cell on its left (upwind). This equals
    the Godunov flux when Q is non-decreasing in c, i.e. flow left to right.
    """
    # Evaluate each flux once; every flux feeds two neighbouring cell updates.
    fluxes = [Q(c_inflow)] + [Q(ci) for ci in c]
    return [-(right - left) / delta for left, right in zip(fluxes, fluxes[1:])]


def godunov_solve(Q, c_init, c_inflow, delta, t):
    """
    Integrate a conservation law in time with the Godunov scheme.

    Parameters
    ----------
    Q : function of c returning the flux.
    c_init : initial value of c in each spatial cell.
    c_inflow : value of c at the left (inflow) boundary.
    delta : spatial step.
    t : times at which the solution is reported.

    Returns
    -------
    The ``scipy.integrate.odeint`` result: one row per entry of ``t``, holding
    c in every cell at that time.
    """
    def rhs(c, _time):
        # odeint passes (state, time); the semi-discrete system is time-independent.
        return godunov_step(Q, delta, c, c_inflow)

    solution = sp_int.odeint(rhs, c_init, t)
    return solution


def safe_div(a, b):
    """
    Return a / b, treating denominators smaller than 1e-9 in magnitude as +-1e-9.

    The sign of a tiny denominator is kept (negative stays negative); zero and
    -0.0 are treated as +1e-9. NaN denominators fall through to a plain a / b.
    """
    eps = 1e-9
    if abs(b) < eps:
        numerator = -a if b < 0 else a
        return numerator / eps
    return a / b


def shock_step(y, t, vs, vc, dvc_dc, c_init, dc_dx0):
    """
    Return the rate of change of the shock state vector, for use in an ODE solver.

    Parameters
    ----------
    y : state vector ``[x, x0L, x0R]``. x is the current shock position; x0L and
        x0R are the initial (t = 0) positions of the characteristics meeting
        the shock from the left and from the right.
    t : current time.
    vs : function ``vs(c_L, c_R)`` giving the shock speed from the values of c
        to the left and right of the shock.
    vc : function giving the characteristic velocity for a given c.
    dvc_dc : function giving d(vc)/dc for a given c.
    c_init : function giving the initial value of c at position x.
    dc_dx0 : function giving d(c_init)/dx at position x.

    Returns
    -------
    ``[dx/dt, dx0L/dt, dx0R/dt]``.

    A characteristic from x0 sits at x0 + t*vc(c_init(x0)). Differentiating this
    in time, and requiring it to move with the shock (speed dx/dt), gives

        dx0/dt = (vc - dx/dt) / (-1 - t * dvc_dc(c) * dc_dx0(x0)).
    """
    # y[0] (the shock position) does not enter the derivatives.
    x0L = y[1]
    x0R = y[2]

    c_L = c_init(x0L)
    c_R = c_init(x0R)

    dx_dt = vs(c_L, c_R)

    # The denominator vanishes when t * dvc_dc * dc_dx0 == -1; safe_div keeps
    # the derivative finite there.
    dx0L_dt = safe_div(vc(c_L) - dx_dt, -1 - t*dvc_dc(c_L)*dc_dx0(x0L))
    dx0R_dt = safe_div(vc(c_R) - dx_dt, -1 - t*dvc_dc(c_R)*dc_dx0(x0R))

    return [dx_dt, dx0L_dt, dx0R_dt]


def solve_shock(Q_expr, c_sym, c_init, x_sym, t_shock, x_shock, x0_shock,
                delta_x, L, delta_t, T):
    """
    Integrate the path of a single shock from its start point to time T.

    Parameters
    ----------
    Q_expr : sympy expression giving the flux in terms of ``c_sym``.
    c_sym : sympy Symbol for the conserved quantity c.
    c_init : sympy expression giving the initial c in terms of ``x_sym``.
    x_sym : sympy Symbol for position.
    t_shock, x_shock : time and position at which the shock starts.
    x0_shock : initial position of the characteristic that leads into the
        start of the shock.
    delta_x, L, delta_t : not used by this function; accepted so existing
        callers keep working.
    T : end time of the integration.

    Returns
    -------
    The ``solve_ivp`` dense-output interpolant (``sln.sol``). Calling it with a
    time in [t_shock, T] returns ``[x, x0L, x0R]``: the shock position and the
    initial positions of the characteristics meeting it from left and right.
    """
    Q = lambdify(c_sym, Q_expr)

    # Characteristic velocity vc(c) = dQ/dc and its derivative with respect to c.
    vc_expr = diff(Q_expr, c_sym)
    vc = lambdify(c_sym, vc_expr)
    dvc_dc = lambdify(c_sym, diff(vc_expr, c_sym))

    # Initial condition and its slope as functions of the starting position x0.
    dc_dx0 = lambdify(x_sym, diff(c_init, x_sym))
    c_init_lamb = lambdify(x_sym, c_init)

    def vs(c_L, c_R):
        """Rankine-Hugoniot shock speed (Q(c_L) - Q(c_R)) / (c_L - c_R)."""
        jump = c_L - c_R
        if abs(jump) < 1e-9:
            # The quotient is 0/0 for equal states; its limit is dQ/dc.
            return vc(0.5 * (c_L + c_R))
        return (Q(c_L) - Q(c_R)) / jump

    # Start with x = x_shock and x0L, x0R = x0_shock -+ epsilon. If both start
    # exactly at x0_shock then c_L == c_R, vs == vc and they never move.
    sln = sp_int.solve_ivp(
        lambda t, y: shock_step(y, t, vs, vc, dvc_dc, c_init_lamb, dc_dx0),
        [t_shock, T],
        [x_shock, x0_shock - 1e-3, x0_shock + 1e-3],
        dense_output=True, atol=1e-3)
    return sln.sol


def get_argmin(f, x_sym, bounds):
    """
    Return the value of x_sym inside ``bounds`` that globally minimises ``f``.

    ``f`` is a sympy expression of ``x_sym`` and ``bounds`` a ``(low, high)``
    pair; the search is done with scipy's ``optimize.shgo``.
    """
    objective = lambdify(x_sym, f)
    result = optimize.shgo(objective, [bounds])
    best_point = result.x
    return best_point[0]

def get_all_argmin(f, x_sym, bounds):
    """
    Return the local minimisers of ``f`` found strictly inside ``bounds``.

    ``f`` is a sympy expression of ``x_sym``; the minima come from
    ``optimize.shgo`` (5 iterations, Sobol sampling). Minimisers within 1e-3
    of either bound are dropped, so minima sitting on the boundary are not
    reported.
    """
    objective = lambdify(x_sym, f)
    result = optimize.shgo(objective, [bounds], iters=5,
                           sampling_method='sobol')
    lower = bounds[0] + 1e-3
    upper = bounds[1] - 1e-3
    return [x for x in result.xl[:, 0] if lower < x < upper]

def _crosses_shock(x0, x, t, shock):
    """Return True if the characteristic from x0, now at (x, t), has met ``shock``."""
    if not t > shock["t"]:
        return False
    x_shock = shock["sln"](t)[0]
    tol = 1e-1
    # A characteristic that started right of the shock origin has been hit once
    # it lies left of the shock, and vice versa.
    from_right = x0 > shock["x0"] - tol and x < x_shock + tol
    from_left = x0 < shock["x0"] + tol and x > x_shock - tol
    return from_right or from_left


def _trace_characteristic(x0, vc_j, shocks, delta_t, T):
    """
    Step one straight characteristic (speed vc_j) from (x0, 0) until it meets a
    shock or the time reaches T.

    Returns ``(x_ch, t_ch)``: two-element lists holding the start point and the
    last point reached before any shock.
    """
    x_ch = [x0, x0]
    t_ch = [0, 0]
    x = x0
    for i in range(int(T / delta_t)):
        # Advance one step; x and t both describe the END of step i.
        t = delta_t * (i + 1)
        x = x + vc_j * delta_t
        if any(_crosses_shock(x0, x, t, s) for s in shocks):
            break
        x_ch[-1] = x
        t_ch[-1] = t
    return x_ch, t_ch


def _find_shocks(Q, c_sym, c_init, x_sym, vc, vc_diff, delta_x, L, delta_t, T):
    """
    Locate the start of every shock and integrate each shock's path.

    Characteristics from x0 move at vc(c_init(x0)), so neighbouring ones cross
    at t = -1 / vc_diff(x0). Each interior local minimum of vc_diff that is
    negative marks where a shock first forms.
    """
    shocks = []
    for x0_shock in get_all_argmin(vc_diff, x_sym, (0, L)):
        slope = float(vc_diff.subs(x_sym, x0_shock).evalf())
        # Diverging (or parallel) characteristics never cross: no shock here,
        # and -1/slope would be a negative or infinite start time.
        if slope >= 0:
            continue
        vc_shock = float(vc.subs(c_sym, c_init.subs(x_sym, x0_shock).evalf()).evalf())
        t_shock = -1 / slope
        x_shock = x0_shock + t_shock * vc_shock

        shock_sln = solve_shock(Q, c_sym, c_init, x_sym, t_shock, x_shock, x0_shock,
                                delta_x, L, delta_t, T)
        shocks.append({"t": t_shock, "x": x_shock, "x0": x0_shock, "sln": shock_sln})
    return shocks


def solve_characteristics(Q, c_sym, c_init, x_sym, delta_x, L, delta_t, T):
    """
    Trace the characteristics of c_t + Q(c)_x = 0 and locate its shocks.

    Parameters
    ----------
    Q : sympy expression for the flux in terms of ``c_sym``.
    c_sym : sympy Symbol for the conserved quantity.
    c_init : sympy expression for the initial c in terms of ``x_sym``.
    x_sym : sympy Symbol for position.
    delta_x : spatial step (only forwarded to ``solve_shock``).
    L : river length. Characteristics start at x0 = 0, s, 2s, ... below L,
        with spacing s = max(1, int(L/25)).
    delta_t : time step used when stepping characteristics.
    T : final time.

    Returns
    -------
    (chars, shocks)
        chars : list of dicts ``{"x": [x0, x_end], "t": [0, t_end]}`` giving each
            characteristic's start and the last point it reaches before a
            shock (or T).
        shocks : list of dicts ``{"t", "x", "x0", "sln"}``: shock start time and
            position, the initial position of the characteristic leading into
            it, and the ``solve_shock`` interpolant of its path.
    """
    # Characteristic velocity vc(c) = dQ/dc.
    vc = diff(Q, c_sym)

    # Derivative of vc(c_init(x)) with respect to x.
    vc_diff = diff(vc, c_sym).subs(c_sym, c_init) * diff(c_init, x_sym)

    print("Solving shocks... ", end="", flush=True)
    shocks = _find_shocks(Q, c_sym, c_init, x_sym, vc, vc_diff,
                          delta_x, L, delta_t, T)
    print("done")

    if shocks:
        print(f"Shock starts at ({shocks[0]['x']}, {shocks[0]['t']})")
    else:
        print("No shock forms")

    c_of_x0 = lambdify(x_sym, c_init)
    vc_of_c = lambdify(c_sym, vc)

    # Guard against a zero spacing (and division by zero) when L < 25.
    ch_spacing = max(1, int(L/25))
    n_chars = int(L/ch_spacing)
    chars = []

    print("Generating characteristics... [", end="", flush=True)
    for j in range(n_chars):
        x0 = ch_spacing*j
        # c, and hence vc, is constant along a characteristic.
        vc_j = vc_of_c(c_of_x0(x0))
        x_ch, t_ch = _trace_characteristic(x0, vc_j, shocks, delta_t, T)
        chars.append({"x": x_ch, "t": t_ch})

        # 20-segment progress bar.
        if int(20*(j + 1)/n_chars) > int(20*j/n_chars):
            print("=", end="", flush=True)

    print("] done")

    return chars, shocks

def plot_shocks(shocks, delta_t, T, ax):
    """
    Draw each shock path in blue on ``ax``, with position on the x-axis and time
    on the y-axis.

    Shocks starting at or after T are skipped. Each remaining path is sampled
    every ``delta_t`` from the shock's start time, for
    int((T - start) / delta_t) samples.
    """
    print("Plotting shocks... ", end="", flush=True)

    for shock in shocks:
        t_start = shock["t"]
        if not t_start < T:
            continue
        n_steps = int((T - t_start) / delta_t)
        times = [t_start + delta_t * k for k in range(n_steps)]
        # Row 0 of the dense solution is the shock position.
        positions = shock["sln"](np.array(times))[0, :]
        ax.plot(positions, times, color="blue")

    print("done")

def plot_chars(chars, ax):
    """Draw every characteristic (a dict of "x" and "t" point lists) in red on ``ax``."""
    print("Plotting characteristics... ", end="", flush=True)
    for characteristic in chars:
        xs = characteristic["x"]
        ts = characteristic["t"]
        ax.plot(xs, ts, color="red")
    print("done")
    
def gen_fig(prefix, g_val, alpha_val, f_val, w_expr, y_sym, a_sym, a_val,
            delta_x, L, delta_t, T,
            A_init, x_sym):
    """
    Build the river model for one parameter set and save three figures.

    Files written (the parent directory of ``prefix`` is created if needed):
      ``<prefix>-init.pdf``     initial area A(x) at t = 0,
      ``<prefix>-shock.pdf``    characteristics (red) and shock paths (blue),
      ``<prefix>-godunov.pdf``  Godunov numerical solution A(x, t) as a colour map.

    Parameters
    ----------
    prefix : output path prefix for the PDF files.
    g_val, alpha_val, f_val : values substituted for g, alpha and f in the
        velocity u = sqrt(A*g*sin(alpha) / (f*l(A))).
    w_expr : sympy expression for the channel width at height ``y_sym``.
    y_sym : sympy Symbol for height above the riverbed.
    a_sym, a_val : sympy Symbol appearing in ``w_expr`` and its value.
    delta_x, L : spatial step and river length.
    delta_t, T : time step and total simulated time.
    A_init : sympy expression for the initial area in terms of ``x_sym``.
    x_sym : sympy Symbol for position along the river.
    """
    from pathlib import Path

    # Make sure e.g. "fig/" exists before savefig is called.
    Path(prefix).parent.mkdir(parents=True, exist_ok=True)

    g, alpha, f = symbols("g,alpha,f")

    # --- Initial condition ---
    fig, ax = plt.subplots()
    plot_expr(ax, A_init, x_sym, [0, L], int(L/delta_x))
    ax.set_xlabel("Position (m)")
    ax.set_ylabel("Initial Area (m^2)")
    fig.savefig(prefix+"-init.pdf", bbox_inches="tight")
    plt.close(fig)

    # --- Flux Q(A) = A * u(A) ---
    A_sym = Symbol("A", positive=True, real=True)
    l_of_A = derive_wetted_perim(w_expr, y_sym, A_sym)
    u = simplify(sqrt(A_sym*g*sin(alpha) / (f * l_of_A)))
    Q = (A_sym*u).subs([(g, g_val), (alpha, alpha_val), (f, f_val),
                        (a_sym, a_val)])

    # --- Characteristics and shocks ---
    chars, shocks = solve_characteristics(Q, A_sym, A_init,
                                          x_sym, delta_x, L, delta_t, T)

    fig, ax0 = plt.subplots()
    plot_chars(chars, ax0)
    plot_shocks(shocks, delta_t, T, ax0)
    ax0.axis([0, L, 0, T])
    ax0.set_xlabel("Position (m)")
    ax0.set_ylabel("Time (s)")
    fig.savefig(prefix+"-shock.pdf", bbox_inches="tight")
    plt.close(fig)

    # --- Godunov numerical solution ---
    Q_lamb = lambdify(A_sym, Q)

    x = [i*delta_x for i in range(int(L/delta_x))]

    A_inflow = float(A_init.evalf(subs={x_sym: 0}))
    A_init_lamb = lambdify(x_sym, A_init)
    A_init_data = [A_init_lamb(xi) for xi in x]

    t = [i*delta_t for i in range(int(T/delta_t))]

    print("Solving Godunov... ", end="", flush=True)
    c_sln = godunov_solve(Q_lamb, A_init_data, A_inflow, delta_x, t)
    print("done")

    print("Plotting Godunov... ", end="", flush=True)

    fig, ax1 = plt.subplots()
    x_end = len(x)*delta_x
    t_end = len(t)*delta_t
    # Draw cell j over [j*dx, (j+1)*dx] and output row k over [k*dt, (k+1)*dt]
    # so the axes are in metres and seconds; the image is the only mappable.
    im = ax1.imshow(c_sln, origin="lower", aspect="auto",
                    extent=(0, x_end, 0, t_end))
    fig.colorbar(im, ax=ax1)

    n_xticks = 6
    n_tticks = 6
    ax1.set_xticks(np.linspace(0, x_end, n_xticks))
    ax1.set_yticks(np.linspace(0, t_end, n_tticks))
    ax1.set_xlabel("Position (m)")
    ax1.set_ylabel("Time (s)")
    fig.savefig(prefix+"-godunov.pdf", bbox_inches="tight")
    plt.close(fig)

    print("done")


def main():
    """Generate the figure set for every configuration listed in ``plots``."""
    y_sym = Symbol("y", positive=True, real=True)
    x_sym = Symbol("x", positive=True, real=True)
    a_sym = Symbol("a", positive=True, real=True)

    def gen_A_init(A0, A1, s_slope, s_grad):
        """Smooth tanh step from A0 (upstream) to A1 (downstream) centred at
        x = s_slope; s_grad sets how steep the step is."""
        return (A0+A1)/2 - (A0-A1)/2*tanh((x_sym-s_slope)*s_grad)

    plot_defaults = {
        "g_val": 9.81, "alpha_val": 0.01, "f_val": 0.05, "w_expr": a_sym,
        "y_sym": y_sym, "a_sym": a_sym, "a_val": 50,
        "delta_x": 1, "L": 300, "delta_t": 1, "T": 200,
        "A_init": gen_A_init(20,1,100,1/50),
        "x_sym": x_sym
    }

    # Uncomment a preset to regenerate that figure set.
    plots = [
        #plot_defaults | { "prefix": "fig/baseline" },
        #plot_defaults | { "prefix": "alpha-0.02", "alpha_val": 0.02 },
        #plot_defaults | { "prefix": "f-0.1", "f_val": 0.1 },
        #plot_defaults | { "prefix": "dx_0.25,dt_0.25", "delta_x": 0.25, "delta_t": 0.25 },
        #plot_defaults | { "prefix": "fig/a-5", "a_val": 5 },
        #plot_defaults | { "prefix": "fig/a-3", "a_val": 3, "L": 700, "T": 400 },
        #plot_defaults | { "prefix": "fig/a-10", "a_val": 10 },
        #plot_defaults | { "prefix": "fig/a-1", "a_val": 1 },
        #plot_defaults | { "prefix": "fig/slope-0.1", "A_init": gen_A_init(20,1,100,0.1) },
        #plot_defaults | { "prefix": "fig/A-pulse",  "A_init": 1 + 19*1000/(1*1000+(x_sym-100)**2) },
        plot_defaults | { "prefix": "fig/A-pulse-short",  "A_init": 1 + 19*10/(1*10+(x_sym-100)**2) }
    ]

    plt.rcParams.update({'font.size': 22})

    for p in plots:
        gen_fig(**p)


# Only run the (slow) figure generation when executed as a script, not on import.
if __name__ == "__main__":
    main()