import cvxpy as cp
import numpy as np
import scipy.linalg as la

def fix_density_matrix_max_entropy(rho, max_iters=1000, trace_penalty=1e6, positivity_penalty=1e6):
    """
    Solve for a Hermitian PSD matrix whose off-diagonal entries equal those of rho.

    The diagonal of the result is free; it is chosen by a convex program that
    penalizes deviation of the trace from 1.

    rho: Input density matrix (square 2-D numpy array)
    max_iters: Maximum number of iterations passed to the SCS solver
    trace_penalty: Penalty coefficient on |trace - 1|
    positivity_penalty: Penalty coefficient on a negative smallest eigenvalue

    Returns the solver's value for the matrix variable (complex n x n array).

    Raises ValueError if rho is not a square matrix, if both solvers raise a
    SolverError, or if the final status is neither "optimal" nor
    "optimal_inaccurate".

    NOTE(review): the `eigenvalues` variable below is not linked to the matrix
    variable by any constraint. The entropy term therefore only drives
    `eigenvalues` and has no influence on the returned matrix. Actually
    maximizing the entropy of the matrix needs a coupled formulation (for
    example a matrix-entropy atom such as cvxpy's von_neumann_entr, where the
    installed version provides it); that changes results and solve cost, so
    it is deliberately not done here.
    """
    rho = np.asarray(rho)
    if rho.ndim != 2 or rho.shape[0] != rho.shape[1]:
        raise ValueError(f"rho must be a square matrix, got shape {rho.shape}")
    n = rho.shape[0]

    # Stand-alone eigenvalue vector used only by the entropy term (see NOTE).
    eigenvalues = cp.Variable(n, nonneg=True)

    # Declaring the variable Hermitian replaces an explicit X == X^H constraint.
    rho_var = cp.Variable((n, n), hermitian=True)

    # One masked equality pins every off-diagonal entry at once instead of
    # emitting n*(n-1) scalar constraints; the diagonal rows reduce to 0 == 0.
    off_diagonal = 1.0 - np.eye(n)
    constraints = [
        cp.multiply(off_diagonal, rho_var) == off_diagonal * rho,
        rho_var >> 0,  # hard positive-semidefiniteness constraint
        cp.sum(eigenvalues) == 1,  # the stand-alone eigenvalues sum to 1
    ]

    # Small shift that keeps the entropy argument strictly positive.
    epsilon = 1e-10

    # Soft trace constraint: the trace is penalized, not forced, to equal 1.
    trace_error = cp.abs(cp.trace(rho_var) - 1)

    # Penalty on a negative smallest eigenvalue. With the hard PSD constraint
    # above this is zero at every feasible point; it is kept so that
    # positivity_penalty keeps its meaning.
    min_eigenvalue = cp.lambda_min(rho_var)
    positivity_error = cp.maximum(-min_eigenvalue, 0)

    # Minimize negative entropy of `eigenvalues` plus the two penalties.
    obj = cp.Minimize(
        cp.sum(-cp.entr(eigenvalues + epsilon)) +
        trace_penalty * trace_error +
        positivity_penalty * positivity_error
    )

    prob = cp.Problem(obj, constraints)

    try:
        prob.solve(solver=cp.SCS, verbose=False, max_iters=max_iters)
    except cp.error.SolverError:
        # If SCS fails, fall back to MOSEK (if available).
        try:
            prob.solve(solver=cp.MOSEK, verbose=False)
        except cp.error.SolverError as err:
            # Keep the underlying solver error attached for debugging.
            raise ValueError("Optimization failed.") from err

    if prob.status not in ["optimal", "optimal_inaccurate"]:
        raise ValueError(f"Optimization failed: {prob.status}")

    return rho_var.value

def generate_density_matrix(n, seed, max_eigenvalue=1.0, *, max_attempts=10_000):
    """
    Generates a random density matrix.

    A random complex matrix is symmetrized into a Hermitian matrix h, mapped
    through the matrix exponential, and divided by its trace. Candidates are
    redrawn until the largest absolute eigenvalue is at most max_eigenvalue.

    n: Matrix dimension (must be at least 1)
    seed: Seed for a private RandomState; the global numpy RNG is not touched
    max_eigenvalue: Upper bound on the largest absolute eigenvalue
    max_attempts: Maximum number of candidates drawn before giving up

    Returns the accepted n x n complex array with a real diagonal.

    Raises ValueError if n < 1, if max_eigenvalue is below 1/n (unattainable
    for a trace-one positive matrix), or if no candidate is accepted within
    max_attempts draws.
    """
    if n < 1:
        raise ValueError(f"n must be at least 1, got {n}")
    # The eigenvalues are positive and sum to 1, so the largest is >= 1/n;
    # a smaller bound could never be met and the search would not terminate.
    if max_eigenvalue * n < 1:
        raise ValueError(
            f"max_eigenvalue={max_eigenvalue} is below 1/n for n={n}; "
            "no density matrix can satisfy it"
        )

    # A private legacy generator yields the same draws as np.random.seed(seed)
    # followed by np.random.rand, without reseeding the global RNG.
    rng = np.random.RandomState(seed)

    for _ in range(max_attempts):
        # Generate a random Hermitian matrix
        a = rng.rand(n, n) + 1j * rng.rand(n, n)
        h = (a + a.conj().T) / 2

        # Ensure positive definiteness using the exponential map
        rho = la.expm(h)
        rho /= np.trace(rho)

        # Drop the imaginary parts of the diagonal entries
        np.fill_diagonal(rho, np.real(np.diag(rho)))

        # rho is Hermitian, so use the Hermitian solver (real eigenvalues)
        eigenvalues = np.linalg.eigvalsh(rho)
        if np.max(np.abs(eigenvalues)) <= max_eigenvalue:
            return rho

    raise ValueError(
        f"No density matrix with largest eigenvalue <= {max_eigenvalue} "
        f"found in {max_attempts} attempts"
    )

def main(seeds, dimensions):
    """
    Fixes density matrices for multiple seeds and dimensions, and outputs the results.

    For every (dimension, seed) pair a random density matrix is generated,
    passed through fix_density_matrix_max_entropy, and a report is printed:
    both matrices, the eigenvalues of the result, a positive-semidefiniteness
    check, the trace, and the largest entrywise deviation from the input.

    seeds: List of seed values
    dimensions: List of dimensions
    """
    for n in dimensions:
        for seed in seeds:
            print(f"\n=== Final Result, Dimension {n}, Seed {seed} ===")

            # Generate the original density matrix
            rho = generate_density_matrix(n, seed)

            # Compute the corrected density matrix
            rho_fixed = fix_density_matrix_max_entropy(rho)

            # Solver output is Hermitian only up to tolerance. Take the
            # Hermitian part and use the Hermitian eigensolver so the
            # eigenvalues are real; a general solver returns complex values,
            # for which ">=" is a lexicographic comparison.
            hermitian_part = (rho_fixed + rho_fixed.conj().T) / 2
            eigenvalues = np.linalg.eigvalsh(hermitian_part)

            # Check positive semidefiniteness (small negative slack allowed)
            is_positive_semidefinite = np.all(eigenvalues >= -1e-10)

            # Check the trace
            trace_fixed = np.trace(rho_fixed)

            # Largest entrywise change, diagonal included
            max_error = np.max(np.abs(rho_fixed - rho))

            # Output the results
            print("Original Density Matrix:")
            print(rho)
            print("\nCorrected Density Matrix:")
            print(rho_fixed)
            print("\nEigenvalues:", eigenvalues)
            print("Positive Semidefiniteness Check:", is_positive_semidefinite)
            print(f"Trace: {trace_fixed:.15f}")
            print("Maximum Error:", max_error)

# Demo run, executed whenever this module is loaded.
# NOTE(review): consider an `if __name__ == "__main__":` guard so that
# importing the module does not start the solve.
dimensions = [50]  # matrix sizes to process
seeds = [1]  # random seeds to process for each size
main(seeds, dimensions)