#include <stdio.h>
#include <math.h>
#include <complex.h>

#define N 4          /* dimension of every matrix/vector in this file */
#define EPS 1e-6     /* stop once the summed |x_next - x| of the normalized iterate is <= this */
#define MAX_ITER 100 /* hard cap on inverse-iteration steps per shift */

/*
 * Matrix 1. All entries are real; the array is declared complex so it can be
 * shifted by a complex lambda_hat and fed to the complex LU/solve routines
 * without any conversion. Each real literal initializes the real part and
 * leaves the imaginary part at +0.0.
 */
double complex A[N][N] = {
    { 1.0,  2.0, 3.0,  5.0 },
    { 3.0, -5.0, 1.0,  4.0 },
    { 5.0,  9.0, 2.0, -6.0 },
    { 1.0,  7.0, 4.0,  1.0 },
};

// Complex LU Decomposition with Partial Pivoting
/*
 * Factors B so that P*B = L*U, where P is a row permutation.
 *
 * B - input matrix; only read (a working copy is factored).
 * L - output: unit lower-triangular factor (ones on the diagonal).
 * U - output: upper-triangular factor.
 * P - output: row permutation; row i of P*B is row P[i] of B.
 *
 * Elimination runs in place on the working copy W. At step i, column i on and
 * below the diagonal therefore holds the *reduced* entries, and the pivot is
 * the one of largest magnitude among them. Entire rows of W are swapped,
 * including the multipliers already stored left of the diagonal. This keeps
 * the rows of L consistent with the permutation. Swapping only the
 * not-yet-factored part while leaving L's rows in place breaks P*B = L*U
 * whenever a swap happens after the first column.
 *
 * If a column has no nonzero candidate, B is singular. That step is skipped,
 * its multipliers stay 0 and U[i][i] is left as 0. Any later solve with U
 * will then divide by that zero pivot.
 */
void lu_decomposition_complex(double complex B[N][N], double complex L[N][N], double complex U[N][N], int P[N]) {
    /* Working copy: U on/above the diagonal, L multipliers below it. */
    double complex W[N][N];
    for (int i = 0; i < N; i++) {
        P[i] = i;
        for (int j = 0; j < N; j++) {
            W[i][j] = B[i][j];
        }
    }

    for (int i = 0; i < N; i++) {
        /* Partial pivoting on the reduced column (cabs = complex magnitude). */
        int pivot_row = i;
        double max_val = cabs(W[i][i]);
        for (int r = i + 1; r < N; r++) {
            double mag = cabs(W[r][i]);
            if (mag > max_val) {
                max_val = mag;
                pivot_row = r;
            }
        }

        if (pivot_row != i) {
            int t = P[i]; P[i] = P[pivot_row]; P[pivot_row] = t;
            /* Swap whole rows so stored multipliers (columns < i) follow. */
            for (int k = 0; k < N; k++) {
                double complex tmp = W[i][k];
                W[i][k] = W[pivot_row][k];
                W[pivot_row][k] = tmp;
            }
        }

        if (max_val == 0.0) {
            continue; /* singular column: nothing below to eliminate */
        }

        for (int r = i + 1; r < N; r++) {
            double complex m = W[r][i] / W[i][i];
            W[r][i] = m; /* store the multiplier as L[r][i] */
            for (int c = i + 1; c < N; c++) {
                W[r][c] -= m * W[i][c];
            }
        }
    }

    /* Split the packed factors into the caller's L and U. */
    for (int i = 0; i < N; i++) {
        for (int j = 0; j < N; j++) {
            L[i][j] = (j < i) ? W[i][j] : (i == j ? 1.0 : 0.0);
            U[i][j] = (j >= i) ? W[i][j] : 0.0;
        }
    }
}

// Forward Substitution
/*
 * Solves the unit lower-triangular system L*y = P*x.
 *
 * L - unit lower-triangular factor; its diagonal is taken to be 1 and is
 *     never read.
 * x - right-hand side in the original (unpermuted) row order.
 * y - output solution vector.
 * P - row permutation from the LU step; entry i of P*x is x[P[i]].
 *
 * x is fully copied (permuted) into a local buffer before any entry of y is
 * written.
 */
void forward_substitution(double complex L[N][N], double complex x[N], double complex y[N], int P[N]) {
    double complex rhs[N];
    for (int row = 0; row < N; ++row) {
        rhs[row] = x[P[row]];
    }

    for (int row = 0; row < N; ++row) {
        double complex acc = 0.0;
        for (int col = 0; col < row; ++col) {
            acc += L[row][col] * y[col];
        }
        y[row] = rhs[row] - acc;
    }
}

// Backward Substitution
/*
 * Solves the upper-triangular system U*x = y.
 *
 * Rows are processed from the last to the first. Each row uses the entries
 * x[row+1 .. N-1] that were already solved. No check is made for a zero
 * diagonal entry of U.
 */
void backward_substitution(double complex U[N][N], double complex y[N], double complex x[N]) {
    for (int k = 1; k <= N; ++k) {
        const int row = N - k;
        double complex acc = 0.0;
        for (int col = row + 1; col < N; ++col) {
            acc += U[row][col] * x[col];
        }
        x[row] = (y[row] - acc) / U[row][row];
    }
}

// Complex Inverse Iteration Routine
/*
 * Inverse iteration on the global matrix A with a (possibly complex) shift.
 *
 * Steps:
 *  1. Form B = A - lambda_hat*I and factor it once with
 *     lu_decomposition_complex.
 *  2. Repeatedly solve B*x_next = x with the two triangular solves.
 *  3. Rescale x_next so its largest-magnitude component becomes exactly 1.
 *  4. Take lambda_hat + 1/max_comp as the eigenvalue estimate, where max_comp
 *     is that largest component before rescaling.
 *
 * Inverse iteration normally converges to the eigenvector whose eigenvalue
 * lies nearest lambda_hat. This assumes that eigenvalue is strictly nearest
 * and the start vector has a component along its eigenvector.
 *
 * The loop stops when the summed componentwise change of the normalized
 * iterate is <= EPS, or after MAX_ITER steps. Per-iteration estimates and the
 * final result are printed to stdout. The result is only labelled
 * "CONVERGED" when the tolerance was actually met. Hitting MAX_ITER, or a NaN
 * change, is reported as not converged.
 *
 * NOTE: if lambda_hat is exactly an eigenvalue of A, B is singular and the
 * solves divide by a zero pivot (inf/NaN results).
 */
void find_eigen_complex(double complex lambda_hat) {
    double complex B[N][N], L[N][N], U[N][N];
    double complex y[N], x_next[N];
    int P[N];

    // Initial guess vector x^(0) = [1, 1, ..., 1]^T (sized by N, not hard-coded)
    double complex x[N];
    for (int i = 0; i < N; i++) {
        x[i] = 1.0;
    }

    for (int i = 0; i < N; i++) {
        for (int j = 0; j < N; j++) {
            B[i][j] = A[i][j] - (i == j ? lambda_hat : 0.0);
        }
    }

    lu_decomposition_complex(B, L, U, P);

    printf("\n==================================================================\n");
    printf(" Inverse Iteration with Complex Shift: lambda_hat = %.2f + %.2fi\n", creal(lambda_hat), cimag(lambda_hat));
    printf("==================================================================\n");

    int iter = 0;
    double diff = 0.0;
    double complex lambda = 0.0;

    do {
        forward_substitution(L, x, y, P);
        backward_substitution(U, y, x_next);

        // Scaling factor: the component of largest magnitude (keeps its phase)
        double complex max_comp = x_next[0];
        double max_mag = cabs(x_next[0]);
        for (int i = 1; i < N; i++) {
            if (cabs(x_next[i]) > max_mag) {
                max_mag = cabs(x_next[i]);
                max_comp = x_next[i];
            }
        }

        // Normalize so the dominant component is exactly 1 (fixes scale and phase)
        for (int i = 0; i < N; i++) {
            x_next[i] /= max_comp;
        }

        // Convergence measure: summed magnitude of the change in the iterate
        diff = 0.0;
        for (int i = 0; i < N; i++) {
            diff += cabs(x_next[i] - x[i]);
        }

        for (int i = 0; i < N; i++) {
            x[i] = x_next[i];
        }

        // B^{-1} x ~ x / (lambda - lambda_hat), so max_comp ~ 1/(lambda - lambda_hat)
        lambda = lambda_hat + (1.0 / max_comp);

        printf("Iter %2d: lambda = %7.4f + %7.4fi\n", iter + 1, creal(lambda), cimag(lambda));

        iter++;
    } while (diff > EPS && iter < MAX_ITER);

    // `diff <= EPS` is false for NaN as well, so a breakdown is not reported as convergence.
    if (diff <= EPS) {
        printf("\n CONVERGED SOLUTION:\n");
    } else {
        printf("\n NOT CONVERGED after %d iterations (last change %.3e, tolerance %.1e); last iterate:\n",
               iter, diff, EPS);
    }
    printf(" Eigenvalue (lambda) = %9.4f + %9.4fi\n", creal(lambda), cimag(lambda));
    printf(" Eigenvector (x)    = [");
    for (int i = 0; i < N; i++) {
        printf("%s%.4f+%.4fi", i ? ", " : "", creal(x[i]), cimag(x[i]));
    }
    printf("]\n");
    printf("==================================================================\n");
}

/*
 * Runs inverse iteration on Matrix 1 from three starting shifts, in order.
 * Each run prints its own progress and result.
 */
int main(void) {
    /* First a real shift aimed at the real negative eigenvalue. Then a
     * conjugate pair of complex shifts aimed at the upper and lower members
     * of the complex eigenvalue pair. */
    const double complex shifts[] = {
        -5.0 + 0.0*I,
        -2.0 + 1.5*I,
        -2.0 - 1.5*I,
    };
    const size_t n_shifts = sizeof shifts / sizeof shifts[0];

    printf("=== Complex Inverse Iteration Analysis for Matrix 1 ===\n");
    for (size_t k = 0; k < n_shifts; ++k) {
        find_eigen_complex(shifts[k]);
    }
    return 0;
}