diff --git a/flood.py b/flood.py
index ff0541ae058fe16cef0ca13f66e69be76eb43554..75c3b23d5a1b124fce4fa6f4f9cfd7a82dcfb846 100644
--- a/flood.py
+++ b/flood.py
@@ -1,8 +1,14 @@
 from sympy import *
 from matplotlib import pyplot as plt
 from scipy import integrate as sp_int
+from scipy import optimize
 import numpy as np
 
+"""
+Plot a sympy expression f(x_sym) with respect to variable x_sym.
+x_sym in range lim = [low, high].
+N is the number of sample points.
+"""
 def plot_expr(f, x_sym, lim, N):
     F = lambdify(x_sym, f)
     x = []
@@ -14,15 +20,13 @@ def plot_expr(f, x_sym, lim, N):
     plt.plot(x, y)
     plt.show()
 
-def godunov_step(Q, delta, c, c_inflow):
-    N = len(c)
-    dc_dt = [-(Q(c[0]) - Q(c_inflow))/delta]
-    for i in range(1, N):
-        dc_dt.append(-(Q(c[i]) - Q(c[i-1]))/delta)
-    return dc_dt
 
-
-    
+"""
+Derive an analytical expression for the wetted perimeter of the river.
+The width of the river at height y above the riverbed is given by w, which is
+a sympy expression in terms of y. A_sym is a symbol representing the area filled
+with water. The expression returned from this function will be in terms of A_sym.
+"""
 def derive_wetted_perim(w, y, A_sym, debug=False):
     # Define symbols and functions
     l = Function("l", positive=True, real=True)(A_sym)
@@ -60,39 +64,223 @@ def derive_wetted_perim(w, y, A_sym, debug=False):
 
         return l_of_A
 
-def shock_step(y, t, vs, vc, dvc_dc, dc_dx0):
+"""
+Returns the rate of change of c using Godunov method, for use in an ODE solver.
+Q is a python function taking c (i.e. A in our case) as a parameter and calculating
+the flux. delta is the distance timestep along the length of the river. c an array
+containing the current conservation law parameter at each unit length. c_inflow is a
+scalar giving the inflow at the left side of the river.
+"""
+def godunov_step(Q, delta, c, c_inflow):
+    N = len(c)
+    dc_dt = [-(Q(c[0]) - Q(c_inflow))/delta]
+    for i in range(1, N):
+        dc_dt.append(-(Q(c[i]) - Q(c[i-1]))/delta)
+    return dc_dt
+
+"""
+Solves a conservation problem using the Godunov method, returning a list-of-lists, where
+each list is the value of c for all x-steps, at a different time.
+Q is a python function taking c and calculating the flux. c_init is a list containing
+the initial conservation variable at each length step. c_inflow is the value of c on
+the left. delta is the x-step. t is a list giving the time values for which the ODE
+solution should be given.
+"""
+def godunov_solve(Q, c_init, c_inflow, delta, t):
+    gstep = lambda c, t: godunov_step(Q, delta, c, c_inflow)
+    return sp_int.odeint(gstep, c_init, t)
+
+"""
+Division that can handle x/0
+"""
+def safe_div(a,b):
+    if abs(b) < 1e-9:
+        if b < 0:
+            return -a/1e-9
+        else:
+            return a/1e-9
+    else:
+        return a/b
+
+"""
+Takes a single timestep calculating the propagation of a shock, returning the rate of
+change of the current state vector y.
+y is the current parameter vector: [x, x0L, x0R], where x is the current x-position
+of the shock, x0L is the initial (t=0) x-value of the characteristic meeting the
+shock from the left, and x0R is the same but for the right side.
+t is the current time.
+vs is a python function giving the shock speed given the value
+of c to the left and right of the shock (c_L and c_R). vc is a python function giving the
+characteristic velocity given c. dvc_dc is a python function giving the rate of change of
+vc with respect to c, for a given c. c_init is a python function giving the initial value
+of c for a given x. dc_dx0 is a python function giving the rate of change of c_init with
+respect to x.
+"""
+def shock_step(y, t, vs, vc, dvc_dc, c_init, dc_dx0):
     x = y[0]
-    c_L = y[1]
-    c_R = y[2]
+    x0L = y[1]
+    x0R = y[2]
+
+    c_L = c_init(x0L)
+    c_R = c_init(x0R)
     
     dx_dt = vs(c_L, c_R)
 
     vc_L = vc(c_L)
     vc_R = vc(c_R)
-    x0L = x - t*vc_L
-    x0R = x - t*vc_R
+
     dcL_dx0 = dc_dx0(x0L)
     dcR_dx0 = dc_dx0(x0R)
 
-    dcL_dt = dcL_dx0 * (vc_L - dx_dt) / (1 - t*dvc_dc(c_L)*dcL_dx0)
-    dcR_dt = dcR_dx0 * (vc_R - dx_dt) / (1 - t*dvc_dc(c_R)*dcR_dx0)
+    dx0L_dt = safe_div((vc_L - dx_dt), (-1 - t*dvc_dc(c_L)*dcL_dx0))
+    dx0R_dt = safe_div((vc_R - dx_dt), (-1 - t*dvc_dc(c_R)*dcR_dx0))
+    
+    dcL_dt = dcL_dx0 * dx0L_dt
+    dcR_dt = dcR_dx0 * dx0R_dt
+
+    ret = [dx_dt, dx0L_dt, dx0R_dt]
+    #print(ret)
+    return ret
+
+"""
+Finds a solution to the shock position.
+Q_expr is a sympy expression giving the flux in terms of c_sym.
+c_init is a sympy expression giving the initial c in terms of x_sym.
+x_shock is the x-coordinate of the start of the shock.
+t_shock is the time of the start of the shock.
+x0_shock is the x-coordinate whose characteristic leads into the start of the shock.
+delta_x is the x simulation step and L is the total length.
+delta_t is the t simulation step and T is the total time.
+# (x0, t0) is the start of the shock
+"""
+def solve_shock(Q_expr, c_sym, c_init, x_sym, t_shock, x_shock, x0_shock,
+                delta_x, L, delta_t, T):
+    Q = lambdify(c_sym, Q_expr)
+
+    # Get function for char. velocity
+    vc_expr = diff(Q_expr, c_sym)
+    vc = lambdify(c_sym, vc_expr)
+    
+    # Get function for shock speed (addition of 1e-9 prevents 0/0 error)
+    vs = lambda c_L,c_R: (Q(c_L) - Q(c_R) + vc(c_L)*1e-9) / (c_L - c_R + 1e-9)
 
-    return [dx_dt, dcL_dt, dcR_dt]
     
-def solve_shock(Q, c_init, x0, t0):
-    pass 
+
+    dvc_dc = lambdify(c_sym, diff(vc_expr, c_sym))
+    dc_dx0 = lambdify(x_sym, diff(c_init, x_sym))
+    c_init_lamb = lambdify(x_sym, c_init)
+    t_eval = [t_shock+i*delta_t for i in range(int((T-t_shock)/delta_t))]
+
+    # Solve shock equation. Start with x=x0, x0L,x0R=x0_shock -+ epsillon
+    # x0L and x0R must start slightly off from x0_shock, otherwise they don't change.
+    sln = sp_int.solve_ivp(lambda t,y: shock_step(y, t, vs, vc,
+                                                  dvc_dc, c_init_lamb, dc_dx0),
+                           [t_shock, T],
+                           [x_shock, x0_shock-1e-3, x0_shock+1e-3],
+                           dense_output=True, atol=1e-3)
+    return sln.sol
+
+"""
+Get the global argmin of a function
+"""
+def get_argmin(f, x_sym, bounds):
+    f_lamb = lambdify(x_sym, f)
+    res = optimize.shgo(f_lamb, [bounds])
+    return res.x[0]
+
+def get_all_argmin(f, x_sym, bounds):
+    f_lamb = lambdify(x_sym, f)
+    res = optimize.shgo(f_lamb, [bounds], iters=5,
+                        sampling_method='sobol')
+    n = res.xl.shape[0]
+    minima = []
+    for i in range(n):
+        x = res.xl[i, 0]
+        if x > bounds[0]+1e-3 and x < bounds[1]-1e-3:
+            minima.append(x)
+    #print(minima)
+    return minima
+
+def solve_characteristics(Q, c_sym, c_init, x_sym, delta_x, L, delta_t, T):
+    # Get expression for char. velocity
+    vc = diff(Q, c_sym)
+
+    # Derivative of vc(c_init(x)) wrt x
+    vc_diff = diff(vc, c_sym).subs(c_sym, c_init) * diff(c_init, x_sym)
+
+    # Find all the shock start points
+    x0_shock_arr = get_all_argmin(vc_diff, x_sym, (0, L))
+    shocks = []
+
+    # Solve each shock
+    for x0_shock in x0_shock_arr:
+        vc_shock = vc.subs(c_sym, c_init.subs(x_sym, x0_shock).evalf()).evalf()
+        t_shock = -1/(vc_diff.subs(x_sym, x0_shock).evalf())
+        x_shock = x0_shock + t_shock * vc_shock
+
+        shock_sln = solve_shock(Q, c_sym, c_init, x_sym, t_shock, x_shock, x0_shock,
+                                delta_x, L, delta_t, T)
+        shocks.append({"t": t_shock, "x": x_shock, "x0": x0_shock, "sln": shock_sln})
+        
+
+    c_of_x0 = lambdify(x_sym, c_init)
+    vc_of_c = lambdify(c_sym, vc)
+
+    chars = []
     
-def solve_characteristics(Q, c_init, delta_x, delta_t, T):
-    ch = [{"pts":[],"x":c} for c in c_init]
+    # Generate characteristics starting from uniform x-ordinates
+    ch_spacing = 4
+    for j in range(int(L/ch_spacing)): # For each characteristic
+        x0 = x = ch_spacing*j
+
+        # Empty lists to store x and t values
+        x_ch = [x0, x0]
+        t_ch = [0, 0]
 
-    for i in range(int(T/delta_t)):
-        t = delta_t * i
-        for char in ch:
-            char["pts"].append()
+        # Value of c and vc along this characteristic
+        c_j = c_of_x0(x0)
+        vc_j = vc_of_c(c_j)
+        
+        for i in range(int(T/delta_t)):
+            t = delta_t * i
+            x = x + vc_j*delta_t
+
+            hit_shock = False
+            
+            for s in shocks:
+                # The shock solution is indexed in time from the start of the shock,
+                # so adjust the index accordingly
+                shock_idx = int(i - s["t"]/delta_t)
+                
+                # If the char. crosses the shock, stop
+                if (t > s["t"]) and (
+                        ((x0 > s["x0"]-1e-1) and (x < s["sln"](t)[0]+1e-1)) or
+                        ((x0 < s["x0"]+1e-1) and (x > s["sln"](t)[0]-1e-1))):
+                    hit_shock = True
+                    break
+
+            if hit_shock:
+                break
+            
+            x_ch[-1] = x
+            t_ch[-1] = t
+
+        # Store
+        chars.append({"x": x_ch, "t": t_ch})
+
+    return chars, shocks
+
+def plot_shocks(shocks, delta_t, T, ax):
+    # Plot the shock solutions
+    for s in shocks:
+        t_arr = [i*delta_t+s["t"] for i in range(int((T-s["t"])/delta_t))]
+        y = [s["sln"](t)[0] for t in t_arr]
+        ax.plot(y, t_arr, color="blue")
+
+def plot_chars(chars, ax):
+    for ch in chars:
+        ax.plot(ch["x"], ch["t"], color="red")
     
-def godunov_solve(Q, c_init, c_inflow, delta, t):
-    gstep = lambda c, t: godunov_step(Q, delta, c, c_inflow)
-    return sp_int.odeint(gstep, c_init, t)
     
 def main():
     g, alpha, f = symbols("g,alpha,f")
@@ -105,32 +293,50 @@ def main():
     A_sym = Symbol("A", positive=True, real=True)
 
     l_of_A = derive_wetted_perim(w, y, A_sym)
-    print(l_of_A)
+    #print(l_of_A)
 
     u = simplify(sqrt(A_sym*g*sin(alpha) / (f * l_of_A)))
-    print("u(A) = ")
-    pprint(u)
+    #print("u(A) = ")
+    #pprint(u)
     
     Q = (A_sym*u).subs([("g", 9.81), ("alpha", 0.001), ("f", 0.05), (a, 100)])
-    #plot_expr(diff(Q, A_sym), A_sym, [0,30], 1000)
-    #return
-    Q_lamb = lambdify(A_sym, Q)
 
-    delta_x = 0.1
-    L = 100
+    delta_x = 0.5
+    L = 300
 
-    delta_t = 0.1
+    delta_t = 0.5
     T = 200
+
+    x_sym = Symbol("x", real=True)
+    A_init = 1 + 10000/(10+(x_sym-100)**2)
+    
+    chars, shocks = solve_characteristics(Q, A_sym, A_init,
+                                          x_sym, delta_x, L, delta_t, T)
+
+    ax0 = plt.subplot(2, 1, 1)
+
+    plot_chars(chars, ax0)
+    plot_shocks(shocks, delta_t, T, ax0)
+    ax0.axis([0, L, 0, T])
     
-    x = np.array([i*delta_x for i in range(int(L/delta_x))])
-    A_init = 1 + 1000/(100+(x-10)**2)
+    ax1 = plt.subplot(2, 1, 2)
     
+    #plot_expr(diff(Q, A_sym), A_sym, [0,30], 1000)
+    #return
+    Q_lamb = lambdify(A_sym, Q)
+
+    x = [i*delta_x for i in range(int(L/delta_x))]
+
+    A_init_lamb = lambdify(x_sym, A_init)
+    A_init_data = [A_init_lamb(xi) for xi in x]
     A_inflow = 1
+
     t = [i*delta_t for i in range(int(T/delta_t))]
     
-    c_sln = godunov_solve(Q_lamb, A_init, A_inflow, delta_x, t)
+    c_sln = godunov_solve(Q_lamb, A_init_data, A_inflow, delta_x, t)
     
-    plt.imshow(c_sln, origin="lower", aspect="auto")
+    ax1.imshow(c_sln, origin="lower", aspect="auto")
+
     plt.show()