import numpy as np
import matplotlib.pyplot as plt

# Print NumPy floats in fixed-point notation (no scientific notation), so
# values that are zero at print precision show as 0.
np.set_printoptions(suppress=True)


def heatFTCS(nt=10, nx=20, alpha=0.1, L=1, tmax=0.5, errPlots=1):
    """Solve u_t = alpha * u_xx on [0, L] with the explicit FTCS scheme.

    Initial condition u(x, 0) = sin(pi x / L); both boundary nodes are held
    at 0.  The result at t = tmax is compared with the exact solution
    sin(pi x / L) * exp(-alpha (pi / L)^2 tmax).

    Parameters
    ----------
    nt : int
        Number of time levels, including t = 0.
    nx : int
        Number of grid points, including both boundary nodes.
    alpha : float
        Diffusion coefficient.
    L : float
        Domain length.
    tmax : float
        Final time.
    errPlots : int
        0 skips plotting.  Any other value plots the final numerical vs.
        exact solution and the pointwise error, then calls ``plt.show()``.

    Returns
    -------
    (err, x, t, U)
        err -- 2-norm of (numerical - exact) at t = tmax;
        x   -- grid points, shape (nx,);
        t   -- time levels, shape (nt,);
        U   -- solution array, shape (nx, nt), ``U[:, m]`` is the state at t[m].
        Returned whether or not plots are drawn.

    Notes
    -----
    FTCS is only stable for r = alpha*dt/dx**2 <= 1/2; a RuntimeWarning is
    emitted when that limit is exceeded.
    """
    import warnings

    dx = L / (nx - 1)
    dt = tmax / (nt - 1)
    r = alpha * dt / dx / dx
    r2 = 1 - 2 * r
    if r > 0.5:
        warnings.warn(
            f"FTCS is unstable for r = alpha*dt/dx^2 = {r:.3g} > 0.5; "
            "increase nt or decrease nx.",
            RuntimeWarning,
            stacklevel=2,
        )

    x = np.linspace(0, L, nx)
    t = np.linspace(0, tmax, nt)
    U = np.zeros((nx, nt))

    U[:, 0] = np.sin(np.pi * x / L)

    # Boundary rows are never written, so they stay at 0 from np.zeros.
    # Interior update: u_i^m = r*u_{i-1} + (1-2r)*u_i + r*u_{i+1} at level m-1.
    for m in range(1, nt):
        prev = U[:, m - 1]
        U[1:-1, m] = r * prev[:-2] + r2 * prev[1:-1] + r * prev[2:]

    ue = np.sin(np.pi * x / L) * \
        np.exp(-t[nt - 1] * alpha * (np.pi / L) * (np.pi / L))
    err = np.linalg.norm(U[:, nt - 1] - ue)

    if errPlots == 0:
        return err, x, t, U

    _, ax = plt.subplots()
    ax.plot(x, U[:, nt - 1], 'o--', label='FTCS')
    ax.plot(x, ue, '-', label='Exact')
    plt.xlabel("x")
    plt.ylabel("u")
    ax.legend()

    plt.figure()
    plt.plot(x, U[:, nt - 1] - ue, 'o--')
    plt.xlabel('x')
    plt.ylabel('u - ue')
    plt.show()

    return err, x, t, U


def heatBTCS(nt=10, nx=20, alpha=0.1, L=1, tmax=0.5, errPlots=1):
    """Solve u_t = alpha * u_xx on [0, L] with the implicit BTCS scheme.

    Initial condition u(x, 0) = sin(pi x / L); boundary values u0 and uL
    (both 0) are imposed through the first and last rows of the tridiagonal
    system.  Each step solves that system with ``tridiagLU`` /
    ``tridiagLUSolve``.

    Parameters
    ----------
    nt, nx : int
        Number of time levels / grid points (both include the end points).
    alpha : float
        Diffusion coefficient.
    L, tmax : float
        Domain length and final time.
    errPlots : int
        0 skips plotting; any other value plots the final numerical vs.
        exact solution and calls ``plt.show()``.

    Returns
    -------
    (err, x, t, U)
        err -- 2-norm of (numerical - exact) at t = tmax;
        x   -- grid points, shape (nx,);
        t   -- time levels, shape (nt,);
        U   -- shape (nx, nt), ``U[:, m]`` is the state at t[m].
        Returned whether or not plots are drawn.
    """
    dx = L / (nx - 1)
    dt = tmax / (nt - 1)
    x = np.linspace(0, L, nx)
    t = np.linspace(0, tmax, nt)
    U = np.zeros((nx, nt))

    U[:, 0] = np.sin(np.pi * x / L)
    u0 = 0  # value imposed at x = 0
    uL = 0  # value imposed at x = L

    # System rows (scaled by 1/dt): a*u[i-1] + b*u[i] + c*u[i+1] = u_old[i]/dt.
    # c must be its own array, not an alias of a: the boundary rows zero
    # c[0] and a[-1] independently.
    a = np.full((nx, 1), -alpha / dx / dx)
    c = a.copy()
    b = (1 / dt) * np.ones((nx, 1), float) - 2 * a
    # First and last rows reduce to u = d, which imposes u0 / uL.
    b[0] = 1
    b[-1] = 1
    c[0] = 0
    a[-1] = 0
    e, f = tridiagLU(a, b, c)

    for m in range(1, nt):
        d = U[:, m - 1] / dt
        d[0] = u0
        d[-1] = uL
        # tridiagLUSolve writes its result into the buffer it is given, so
        # hand it column m; passing column m-1 would overwrite the previous
        # time level stored in U.
        U[:, m] = tridiagLUSolve(d, a, e, f, U[:, m])

    ue = np.sin(np.pi * x / L) * \
        np.exp(-t[nt - 1] * alpha * (np.pi / L) * (np.pi / L))
    err = np.linalg.norm(U[:, nt - 1] - ue)

    if errPlots == 0:
        return err, x, t, U

    _, ax = plt.subplots()
    ax.plot(x, U[:, nt - 1], 'o--', label='BTCS')
    ax.plot(x, ue, '-', label='Exact')
    plt.xlabel("x")
    plt.ylabel("u")
    ax.legend()
    plt.show()

    return err, x, t, U


def heatCN(nt=10, nx=20, alpha=0.1, L=1, tmax=0.5, errPlots=1):
    """Solve u_t = alpha * u_xx on [0, L] with the Crank-Nicolson scheme.

    Initial condition u(x, 0) = sin(pi x / L); boundary values u0 and uL
    (both 0) are imposed through the first and last rows of the tridiagonal
    system.  Each step averages the explicit and implicit second differences
    (coefficients alpha / (2 dx^2)) and solves the implicit part with
    ``tridiagLU`` / ``tridiagLUSolve``.

    Parameters
    ----------
    nt, nx : int
        Number of time levels / grid points (both include the end points).
        Any nx >= 2 is supported.
    alpha : float
        Diffusion coefficient.
    L, tmax : float
        Domain length and final time.
    errPlots : int
        0 skips plotting; any other value plots the final numerical vs.
        exact solution and the pointwise error, then calls ``plt.show()``.

    Returns
    -------
    (err, x, t, U)
        err -- 2-norm of (numerical - exact) at t = tmax;
        x   -- grid points, shape (nx,);
        t   -- time levels, shape (nt,);
        U   -- shape (nx, nt), ``U[:, m]`` is the state at t[m].
        Returned whether or not plots are drawn.
    """
    dx = L / (nx - 1)
    dt = tmax / (nt - 1)

    x = np.linspace(0, L, nx)
    t = np.linspace(0, tmax, nt)
    U = np.zeros((nx, nt))

    U[:, 0] = np.sin(np.pi * x / L)
    u0 = 0  # value imposed at x = 0
    uL = 0  # value imposed at x = L

    # Implicit operator (scaled by 1/dt): a*u[i-1] + b*u[i] + c*u[i+1] = d[i].
    # 1-D arrays of length nx, so nothing depends on a fixed grid size.
    a = np.full(nx, -alpha / 2 / dx / dx)
    c = np.full(nx, -alpha / 2 / dx / dx)
    b = np.full(nx, 1 / dt) - (a + c)
    # First and last rows reduce to u = d, which imposes u0 / uL.
    b[0] = 1
    c[0] = 0
    b[-1] = 1
    a[-1] = 0

    e, f = tridiagLU(a, b, c)

    # Interior coefficients reused by the explicit half of the scheme.
    a_in = a[1:-1]
    c_in = c[1:-1]

    for m in range(nt - 1):
        u = U[:, m]
        # RHS: u/dt plus the explicit half of the diffusion term;
        # -a*u[i-1] + (a+c)*u[i] - c*u[i+1] equals
        # alpha/(2 dx^2) * (u[i-1] - 2u[i] + u[i+1]).
        d = u / dt
        d[1:-1] = (u[1:-1] / dt - a_in * u[:-2]
                   + (a_in + c_in) * u[1:-1] - c_in * u[2:])
        d[0] = u0
        d[-1] = uL
        # tridiagLUSolve writes into the buffer it is given; pass column m+1
        # so the current level in column m is preserved.
        U[:, m + 1] = tridiagLUSolve(d, a, e, f, U[:, m + 1])

    ue = np.sin(np.pi * x / L) * \
        np.exp(-t[nt - 1] * alpha * (np.pi / L) * (np.pi / L))
    err = np.linalg.norm(U[:, nt - 1] - ue)

    if errPlots == 0:
        return err, x, t, U

    _, ax = plt.subplots()
    ax.plot(x, U[:, nt - 1], 'o--', label='CN')
    ax.plot(x, ue, '-', label='Exact')
    plt.xlabel("x")
    plt.ylabel("u")
    ax.legend()

    # Pointwise error at the final time, one value per grid point.
    plt.figure()
    plt.plot(x, U[:, nt - 1] - ue, 'o--')
    plt.xlabel('x')
    plt.ylabel('u - ue')

    plt.show()

    return err, x, t, U


def tridiagLU(a, b, c):
    """Factor a tridiagonal matrix (Thomas algorithm, no pivoting).

    Row i of the matrix is ``a[i], b[i], c[i]`` (sub-, main- and
    super-diagonal).  The factorisation is A = L U where L is lower
    bidiagonal with diagonal ``e`` and sub-diagonal ``a``, and U is unit
    upper bidiagonal with super-diagonal ``f``.

    Parameters
    ----------
    a, b, c : sequences of length n
        Sub-, main- and super-diagonal coefficients (``a[0]`` is not used).

    Returns
    -------
    (e, f) : two float arrays of shape (n, 1)
        Pivots and upper-factor multipliers.
    """
    size = len(a)
    pivots = np.zeros((size, 1))
    multipliers = np.zeros((size, 1))

    pivots[0] = b[0]
    multipliers[0] = c[0] / b[0]

    k = 1
    while k < size:
        # Eliminate the sub-diagonal entry using the previous multiplier.
        pivots[k] = b[k] - a[k] * multipliers[k - 1]
        multipliers[k] = c[k] / pivots[k]
        k += 1

    return pivots, multipliers


def tridiagLUSolve(d, a, e, f, v):
    """Solve a tridiagonal system from the factors produced by ``tridiagLU``.

    With A = L U (L: diagonal ``e``, sub-diagonal ``a``; U: unit diagonal,
    super-diagonal ``f``), this forward-substitutes L y = d and then
    back-substitutes U v = y.

    Parameters
    ----------
    d : array_like, length n
        Right-hand side.
    a : array_like, length n
        Sub-diagonal of the original matrix (``a[0]`` is not used).
    e, f : array_like, length n
        Factors from ``tridiagLU``; 1-D or (n, 1) column arrays are accepted.
    v : ndarray, length n
        Output buffer, overwritten in place with the solution.

    Returns
    -------
    v
        The same buffer, now holding the solution.
    """
    # Flatten so every indexed element is a scalar whether callers pass
    # 1-D arrays or (n, 1) columns (assigning a 1-element array into a
    # scalar slot is deprecated in recent NumPy).
    d = np.ravel(d)
    a = np.ravel(a)
    e = np.ravel(e)
    f = np.ravel(f)
    n = len(d)
    if n == 0:
        return v

    # Forward substitution: L y = d (y stored in v).
    v[0] = d[0] / e[0]
    for i in range(1, n):
        v[i] = (d[i] - a[i] * v[i - 1]) / e[i]

    # Back substitution: U v = y.
    for i in range(n - 2, -1, -1):
        v[i] = v[i] - f[i] * v[i + 1]

    return v


# Run the three solvers with their default parameters (each shows its plots)
# only when executed as a script, not on import.
if __name__ == "__main__":
    heatCN()
    heatFTCS()
    heatBTCS()
