module Numerics.LinearEquationSystem (
    solveTridiagonal
) where

import           Algebra.VectorSpace
import           Data.Vector         (Vector, (!))
import qualified Data.Vector         as V

-- | Solve a tridiagonal system of equations.
--
-- \[
-- \begin{pmatrix}
--     b_0 & c_0 &        &         &         \\
--     a_0 & b_1 & c_1    &         &         \\
--         & a_1 & b_2    & \ddots  &         \\
--         &     & \ddots & \ddots  & c_{n-2} \\
--         &     &        & a_{n-2} & b_{n-1}
-- \end{pmatrix}
-- \begin{pmatrix}
--     x_0     \\
--     x_1     \\
--     x_2     \\
--     \vdots  \\
--     x_{n-1}
-- \end{pmatrix}
-- =
-- \begin{pmatrix}
--     d_0     \\
--     d_1     \\
--     d_2     \\
--     \vdots  \\
--     d_{n-1}
-- \end{pmatrix}
-- \]
--
-- Translated with blood, sweat and tears from 1-and-2(!!)-based indexing
-- at https://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm
solveTridiagonal
    :: VectorSpace vec
    => Vector Double -- ^ Lower diagonal, length n-1. \(a_0 \ldots a_{n-2}\)
    -> Vector Double -- ^ Diagonal, length n. \(b_0 \ldots b_{n-1}\)
    -> Vector Double -- ^ Upper diagonal, length n-1. \(c_0 \ldots c_{n-2}\)
    -> Vector vec    -- ^ RHS, length n. \(d_0 \ldots d_{n-1}\)
    -> Vector vec    -- ^ Solution, length n. \(x_0 \ldots x_{n-1}\)
solveTridiagonal a b c d
  = let -- The system size is taken from the right-hand side.
        n = V.length d

        -- 'V.imap' with the vector argument first, so the per-element
        -- lambda can trail after '$'.
        ifor = flip V.imap

        -- Forward sweep, modified upper diagonal. Entry i is defined in
        -- terms of entry i-1 of the very same vector, so the elements have
        -- to stay lazy for this self-reference to resolve.
        -- Row i holds a_{i-1} (the lower diagonal starts in row 1), hence
        -- the a!(i-1) lookups.
        c' = ifor c $ \i c_i -> case i of
            0 -> c_i / b!i
            _ -> c_i / (b!i - a!(i-1) * c'!(i-1))

        -- Forward sweep, modified right-hand side. Same recurrence shape as
        -- c', sharing the denominator b_i - a_{i-1} * c'_{i-1}.
        d' = ifor d $ \i d_i -> case i of
            0 -> d_i /. b!i
            _ -> (d_i -. a!(i-1) *. d'!(i-1))  /.  (b!i - a!(i-1) * c'!(i-1))

        -- Back substitution. The last entry is d'_{n-1} itself; every other
        -- entry refers to its successor in x, again by self-reference.
        x = ifor d' $ \i d'_i -> case i of
            _ | i == n-1 -> d'_i
            _            -> d'_i -. c'!i *. x!(i+1)
    in x