%% StegerWarming_1D_Euler_Demo.m
% -------------------------------------------------------------------------
%  Steger--Warming flux vector splitting for 1D Euler equations
%
%  课堂演示目标：
%  1) 观察“按特征值正负分裂”的迎风思想；
%  2) 求解 Sod、Lax、强激波、接触间断、双爆波等典型问题；
%  3) 动态显示 rho/u/p/M，并在最后给出 rho 的 x-t contour。
%
%  方程：
%       U_t + F(U)_x = 0
%
%  Steger--Warming 数值通量：
%       F_{i+1/2} = F^+(U_i) + F^-(U_{i+1})
%
%  说明：
%  - 本代码采用一阶空间格式，优点是稳定、易懂；
%  - 接触间断会被明显抹宽，这是课堂上讨论数值耗散的好例子；
%  - 若希望更锐利，需要结合 MUSCL/TVD/ENO/WENO 等高分辨率重构。
% -------------------------------------------------------------------------

clear; close all; clc;

%% ========== User parameters ==========
caseName = "sod";       % "sod" | "lax" | "strongShock" | "contact" | "blast"
Nx       = 500;         % number of cells; 300--1200 recommended
CFL      = 0.48;        % CFL number
gamma    = 1.4;
plotEvery = 8;          % refresh the animation every plotEvery steps
useSmoothSplit = true;  % true: smooth split sqrt(lambda^2+eps^2); false: raw |lambda|
epsCoeff = 0.08;        % smoothing strength (eps = epsCoeff*c); try 0.03--0.15

%% ========== Initialization ==========
cfg = makeCase(caseName, gamma);
xL = cfg.xL; xR = cfg.xR; tEnd = cfg.tEnd;

% Cell-centred finite-volume grid: Nx cells of width dx exactly tile
% [xL, xR], with x holding the cell centres.
dx = (xR - xL) / Nx;
x  = xL + ((1:Nx) - 0.5) * dx;

[rho, u, p] = initialCondition(x, cfg);
U = prim2cons(rho, u, p, gamma);

t = 0;
step = 0;

% Space-time density history for the final x-t contour. The initial state
% is stored first so the contour starts at t = 0.
histEvery = max(1, floor(Nx/250));
rhoHist = rho;
tHist = t;

fprintf("Case: %s | Nx=%d | CFL=%.3f | tEnd=%.4g\n", caseName, Nx, CFL, tEnd);

%% ========== Figure window ==========
fig = figure('Name','Steger-Warming 1D Euler Demo',...
    'Color','w','Position',[80 80 1280 760]);

%% ========== Time marching ==========
while t < tEnd
    [~, u, ~, c] = cons2prim(U, gamma);
    maxSpeed = max(abs(u) + c);
    dt = CFL * dx / maxSpeed;

    % A NaN/Inf/zero dt means the solution has blown up; stop loudly instead
    % of looping forever or exiting with a meaningless result.
    if ~(isfinite(dt) && dt > 0)
        error('Invalid time step dt=%g at t=%.6g (step %d); the solution has likely blown up.', ...
            dt, t, step);
    end
    if t + dt > tEnd
        dt = tEnd - t;   % land exactly on tEnd
    end

    U = stepStegerWarming(U, dx, dt, gamma, useSmoothSplit, epsCoeff);

    t = t + dt;
    step = step + 1;

    if mod(step, histEvery) == 0 || t >= tEnd
        [rhoNow, ~, ~, ~] = cons2prim(U, gamma);
        rhoHist(end+1,:) = rhoNow; %#ok<SAGROW>
        tHist(end+1) = t; %#ok<SAGROW>
    end

    % Skip redrawing if the user closed the animation window; the
    % computation itself continues.
    if isgraphics(fig) && (mod(step, plotEvery) == 0 || t >= tEnd)
        plotSolution(fig, x, U, gamma, cfg, t, step);
        drawnow;
    end
end

%% ========== Final state: numerical vs exact solution ==========
[rhoN, uN, pN, cN] = cons2prim(U, gamma);
MN = uN ./ cN;

figure('Name','Final solution','Color','w','Position',[120 80 1200 760]);
tl = tiledlayout(2,2,'TileSpacing','compact','Padding','compact');

if cfg.hasExact
    [rhoE, uE, pE] = exactRiemannSolution(x, tEnd, cfg.leftState, cfg.rightState, cfg.x0, gamma);
    ME = uE ./ sqrt(gamma*pE./rhoE);
else
    rhoE = []; uE = []; pE = []; ME = [];
end

nexttile; plotCompare(x, rhoN, rhoE, '\rho'); ylabel('\rho'); grid on;
nexttile; plotCompare(x, uN,   uE,   'u');    ylabel('u');    grid on;
nexttile; plotCompare(x, pN,   pE,   'p');    ylabel('p');    grid on;
nexttile; plotCompare(x, MN,   ME,   'M');    ylabel('M');    grid on;
title(tl, sprintf('Steger--Warming flux vector splitting | %s | t=%.4g', cfg.title, tEnd), ...
    'FontWeight','bold');

%% ========== Space-time contour ==========
figure('Name','Density x-t contour','Color','w','Position',[150 100 1100 460]);
contourf(x, tHist, rhoHist, 40, 'LineStyle','none');
xlabel('x'); ylabel('t');
title(sprintf('\\rho(x,t) contour | %s | Steger--Warming', cfg.title), 'FontWeight','bold');
cb = colorbar; cb.Label.String = '\rho';
grid on;

%% ========== Classroom hints ==========
fprintf("\nDone.\n");
fprintf("课堂观察建议：\n");
fprintf("1) Sod: 看左行膨胀波、接触间断、右行激波。\n");
fprintf("2) contact: 压力和速度不变，密度间断会被一阶格式抹宽。\n");
fprintf("3) blast: 观察强激波相互作用，一阶格式稳定但耗散明显。\n");
fprintf("4) 修改 useSmoothSplit=false，观察特征值过零处通量分裂的变化。\n");

%% ========================================================================
%%                              local functions
%% ========================================================================

function cfg = makeCase(caseName, gamma)
    % MAKECASE  Build the configuration struct for a named 1D test problem.
    %
    %   caseName : "sod" | "lax" | "strongShock" | "contact" | "blast"
    %              (matched case-insensitively)
    %   gamma    : ratio of specific heats, stored as cfg.gamma
    %
    %   cfg has fields gamma, hasExact, title, xL, xR, x0, tEnd, leftState,
    %   rightState. leftState/rightState are [rho,u,p]. Every case uses the
    %   domain [0,1] with the initial interface at x0 = 0.5. The "blast" case
    %   sets hasExact = false.
    %   An unrecognized caseName raises an error.
    cfg.gamma = gamma;
    cfg.hasExact = true;

    switch lower(string(caseName))
        case "sod"
            cfg = fillTubeCase(cfg, "Sod shock tube", 0.20, ...
                [1.0,   0.0, 1.0], [0.125, 0.0, 0.1]);
        case "lax"
            cfg = fillTubeCase(cfg, "Lax shock tube", 0.14, ...
                [0.445, 0.698, 3.528], [0.500, 0.000, 0.571]);
        case "strongshock"
            cfg = fillTubeCase(cfg, "Strong shock tube", 0.012, ...
                [1.0, 0.0, 1000.0], [1.0, 0.0, 0.01]);
        case "contact"
            cfg = fillTubeCase(cfg, "Pure contact discontinuity", 0.18, ...
                [1.0, 0.6, 1.0], [0.2, 0.6, 1.0]);
        case "blast"
            cfg = fillTubeCase(cfg, "Two blast waves", 0.038, ...
                [1.0, 0.0, 1000.0], [1.0, 0.0, 100.0]);
            cfg.hasExact = false;   % no exact Riemann solution for this setup
        otherwise
            error('Unknown caseName: %s', caseName);
    end
end

function cfg = fillTubeCase(cfg, titleStr, tEnd, leftState, rightState)
    % Fill the per-case fields on the unit domain [0,1] with interface at 0.5.
    cfg.title = titleStr;
    cfg.xL = 0;
    cfg.xR = 1;
    cfg.x0 = 0.5;
    cfg.tEnd = tEnd;
    cfg.leftState = leftState;     % [rho,u,p]
    cfg.rightState = rightState;
end

function [rho,u,p] = initialCondition(x, cfg)
    % INITIALCONDITION  Primitive fields rho, u, p at t = 0 on the grid x.
    %
    % Outputs have the same size as x.
    %
    % The case whose cfg.title is "Two blast waves" (case-insensitive) gets a
    % three-state setup:
    %   - rho = 1 and u = 0 everywhere;
    %   - p = 1000 for x < 0.10, p = 100 for x > 0.90, and p = 0.01 in between.
    % Every other case is a single Riemann problem: cfg.leftState for
    % x < cfg.x0 and cfg.rightState for x >= cfg.x0.
    if lower(string(cfg.title)) == "two blast waves"
        rho = ones(size(x));
        u   = zeros(size(x));
        p   = 0.01*ones(size(x));
        p(x < 0.10) = 1000.0;
        p(x > 0.90) = 100.0;
        return;
    end

    onLeft = x < cfg.x0;
    prims = cell(1, 3);
    for k = 1:3
        field = repmat(cfg.rightState(k), size(x));
        field(onLeft) = cfg.leftState(k);
        prims{k} = field;
    end
    [rho, u, p] = prims{:};
end

function U = prim2cons(rho,u,p,gamma)
    % PRIM2CONS  Convert primitive variables to conserved variables.
    %
    %   rho, u, p : density, velocity, pressure (same size; with row-vector
    %               inputs the result is a 3 x N array)
    %   gamma     : ratio of specific heats
    %   U         : [rho; rho.*u; E], where E = p/(gamma-1) + rho*u^2/2 is the
    %               total energy per unit volume
    internalEnergy = p./(gamma-1);
    kineticEnergy  = 0.5*rho.*u.^2;
    momentum       = rho.*u;
    U = [rho; momentum; internalEnergy + kineticEnergy];
end

function [rho,u,p,c] = cons2prim(U,gamma)
    % CONS2PRIM  Convert conserved variables to primitive variables.
    %
    %   U     : 3 x N array of conserved variables [rho; rho*u; E]
    %   gamma : ratio of specific heats
    %   rho, u, p, c : 1 x N density, velocity, pressure and sound speed
    %
    % Density and pressure are floored at 1e-12 to keep strong-shock runs
    % from producing nonphysical values through round-off. The density floor
    % is applied BEFORE dividing the momentum, so a zero-density cell yields
    % a finite velocity rather than 0/0 = NaN.
    floorVal = 1e-12;
    rho = max(U(1,:), floorVal);
    u   = U(2,:)./rho;
    E   = U(3,:);
    p   = max((gamma-1)*(E - 0.5*rho.*u.^2), floorVal);
    c   = sqrt(gamma*p./rho);
end

function Unew = stepStegerWarming(U, dx, dt, gamma, useSmoothSplit, epsCoeff, bcType)
    % STEPSTEGERWARMING  One forward-Euler step of the first-order
    % Steger--Warming finite-volume scheme.
    %
    %   U              : 3 x N conserved variables [rho; rho*u; E]
    %   dx, dt         : cell width and time step
    %   gamma          : ratio of specific heats
    %   useSmoothSplit, epsCoeff : forwarded to splitFluxSW
    %   bcType         : optional boundary treatment, applied at both ends:
    %                    "transmissive" (default) copies the boundary cell
    %                    into two ghost cells; "reflective" mirrors the
    %                    interior cells with the momentum sign flipped
    %                    (solid walls).
    %   Unew           : 3 x N updated conserved variables
    %
    % NOTE(review): the "blast" demo case is usually run with reflecting
    % walls in the literature; pass bcType = "reflective" to reproduce that.
    if nargin < 7 || isempty(bcType)
        bcType = "transmissive";
    end

    N  = size(U,2);
    nG = 2;   % ghost cells per side

    switch lower(string(bcType))
        case "transmissive"
            leftGhost  = repmat(U(:,1), 1, nG);
            rightGhost = repmat(U(:,N), 1, nG);
        case "reflective"
            % Ghost k mirrors interior cell k. The outermost ghost comes
            % first on the left and last on the right.
            flipMom    = [1; -1; 1];
            iL = min(nG:-1:1, N);          % interior cells 2,1 (clamped for tiny N)
            iR = max(N:-1:N-nG+1, 1);      % interior cells N,N-1
            leftGhost  = flipMom .* U(:, iL);
            rightGhost = flipMom .* U(:, iR);
        otherwise
            error('stepStegerWarming:badBC', 'Unknown bcType: %s', bcType);
    end
    Ug = [leftGhost, U, rightGhost];

    nTot = N + 2*nG;
    Fp = zeros(3, nTot);
    Fm = zeros(3, nTot);
    for j = 1:nTot
        [Fp(:,j), Fm(:,j)] = splitFluxSW(Ug(:,j), gamma, useSmoothSplit, epsCoeff);
    end

    % Interface flux H(:,j) = F^+(Ug_j) + F^-(Ug_{j+1}), i.e. the flux across
    % the interface between padded cells j and j+1.
    H = Fp(:,1:end-1) + Fm(:,2:end);

    % Physical cell i sits at padded index i+nG. Its right and left
    % interfaces are H(:,i+nG) and H(:,i+nG-1).
    idx  = (1:N) + nG;
    Unew = U - (dt/dx) * (H(:,idx) - H(:,idx-1));

    % Positivity guard. The primitives are computed here WITHOUT the 1e-12
    % floors that cons2prim applies; with those floors rho <= 0 and p <= 0
    % could never be detected.
    rho = Unew(1,:);
    u   = Unew(2,:)./rho;
    p   = (gamma-1)*(Unew(3,:) - 0.5*rho.*u.^2);
    bad = (rho <= 0) | (p <= 0) | ~isfinite(rho) | ~isfinite(p) | ~isfinite(u);
    if any(bad)
        warning('Nonphysical state detected. Applying small positivity correction.');
        rhoFix = max(rho(bad), 1e-10);   % max ignores NaN, so NaN -> 1e-10
        pFix   = max(p(bad),   1e-10);
        uFix   = u(bad);
        % A velocity from a non-positive or non-finite density is meaningless.
        uFix(~isfinite(uFix) | ~(rho(bad) > 0)) = 0;
        Unew(:,bad) = prim2cons(rhoFix, uFix, pFix, gamma);
    end
end

function [Fp,Fm] = splitFluxSW(U,gamma,useSmoothSplit,epsCoeff)
    % SPLITFLUXSW  Steger--Warming split fluxes F^+(U), F^-(U) for one cell.
    %
    %   U              : 3 x 1 conserved state [rho; rho*u; E]
    %   gamma          : ratio of specific heats
    %   useSmoothSplit : true  -> |lambda| replaced by sqrt(lambda^2 + (epsCoeff*c)^2)
    %                    false -> exact |lambda|
    %   epsCoeff       : smoothing width as a fraction of the sound speed c
    %   Fp, Fm         : 3 x 1 vectors, Fp = R*Lambda^+*R^{-1}*U and
    %                    Fm = R*Lambda^-*R^{-1}*U, so Fp + Fm = R*Lambda*R^{-1}*U
    %
    % R holds the right eigenvectors of the flux Jacobian for the eigenvalues
    % u-c, u, u+c. Density and pressure are floored at 1e-12 (as in cons2prim)
    % so that c stays real even for a nonphysical input state.
    rho = max(U(1), 1e-12);
    u   = U(2)/rho;
    E   = U(3);
    p   = max((gamma-1)*(E - 0.5*rho*u^2), 1e-12);
    c   = sqrt(gamma*p/rho);
    H   = (E + p)/rho;          % total specific enthalpy

    lam = [u-c; u; u+c];
    R = [1,     1,       1;
         u-c,   u,       u+c;
         H-u*c, 0.5*u^2, H+u*c];

    if useSmoothSplit
        % Entropy-fix style smoothing keeps the split fluxes smooth where an
        % eigenvalue changes sign.
        absLam = sqrt(lam.^2 + (epsCoeff*c)^2);
    else
        absLam = abs(lam);
    end
    lamp = 0.5*(lam + absLam);
    lamm = 0.5*(lam - absLam);

    % Characteristic variables w = R^{-1} U via one linear solve instead of
    % forming inv(R); reused for both split fluxes.
    w  = R \ U;
    Fp = R * (lamp .* w);
    Fm = R * (lamm .* w);
end

function plotSolution(fig, x, U, gamma, cfg, t, step)
    % PLOTSOLUTION  Redraw one 2x2 animation frame (rho, u, p, M) in fig.
    %
    %   fig     : figure handle; it is made current and cleared every call
    %   x, U    : grid and 3 x N conserved state
    %   gamma   : ratio of specific heats
    %   cfg     : case struct; cfg.title appears in the first panel's title
    %   t, step : current time and step count, shown in that title
    %
    % The bottom two panels carry the x-axis label.
    figure(fig);
    clf;

    [rho, u, p, c] = cons2prim(U, gamma);
    panels = {rho, u, p, u./c};
    names  = {'\rho', 'u', 'p', 'M'};

    tiledlayout(2,2,'TileSpacing','compact','Padding','compact');
    for k = 1:numel(panels)
        nexttile;
        plot(x, panels{k}, 'LineWidth', 1.6);
        grid on;
        ylabel(names{k});
        if k == 1
            title(sprintf('%s | t=%.4g | step=%d', cfg.title, t, step), 'FontWeight','bold');
        end
        if k >= 3
            xlabel('x');
        end
    end
end

function plotCompare(x, yNum, yExact, nameStr)
    % PLOTCOMPARE  Plot a numerical profile in the current axes and, when
    % yExact is non-empty, overlay it as a dashed black line with a legend.
    %
    %   x, yNum : grid and numerical values
    %   yExact  : exact values on x, or [] when unavailable
    %   nameStr : axes title
    %
    % Leaves "hold on" set for the current axes.
    plot(x, yNum, '-', 'LineWidth', 1.6);
    hold on;
    xlabel('x');
    title(nameStr);

    if isempty(yExact)
        return;
    end
    plot(x, yExact, 'k--', 'LineWidth', 1.2);
    legend('Steger-Warming','Exact','Location','best');
end

function [rho,u,p] = exactRiemannSolution(x,t,leftState,rightState,x0,gamma)
    % EXACTRIEMANNSOLUTION  Exact solution of the 1D ideal-gas Riemann problem.
    %
    %   x          : sample points (outputs have the same size as x)
    %   t          : time; for t <= 0 the initial discontinuity is returned
    %   leftState  : [rho,u,p] for x < x0
    %   rightState : [rho,u,p] for x >= x0
    %   x0         : location of the initial discontinuity
    %   gamma      : ratio of specific heats
    %
    % The star-region pressure pStar is found by Newton iteration on
    % fL(p) + fR(p) + (uR - uL) = 0 (see pressureFunction), starting from a
    % linearized guess. Each point is then classified by its similarity
    % coordinate s = (x - x0)/t relative to the contact (uStar) and the
    % left/right shock or rarefaction.
    %
    % Errors with id 'exactRiemannSolution:vacuum' if the data generate a
    % vacuum, which this solver does not model. Warns with id
    % 'exactRiemannSolution:noConvergence' if Newton does not converge.
    rhoL = leftState(1); uL = leftState(2); pL = leftState(3);
    rhoR = rightState(1); uR = rightState(2); pR = rightState(3);

    if t <= 0
        L = x < x0;
        rho = rhoR*ones(size(x)); u = uR*ones(size(x)); p = pR*ones(size(x));
        rho(L)=rhoL; u(L)=uL; p(L)=pL;
        return;
    end

    cL = sqrt(gamma*pL/rhoL);
    cR = sqrt(gamma*pR/rhoR);

    % Pressure positivity condition: if the two rarefactions can separate
    % faster than this, a vacuum forms and no positive star pressure exists.
    if 2*(cL + cR)/(gamma-1) <= (uR - uL)
        error('exactRiemannSolution:vacuum', ...
            'Initial data generate a vacuum; the exact solver does not handle this case.');
    end

    % Solve for pStar by Newton iteration from the linearized guess.
    pGuess = max(1e-8, 0.5*(pL+pR) - 0.125*(uR-uL)*(rhoL+rhoR)*(cL+cR));
    pStar = pGuess;

    converged = false;
    for k = 1:80
        [fL, dfL] = pressureFunction(pStar, rhoL, pL, cL, gamma);
        [fR, dfR] = pressureFunction(pStar, rhoR, pR, cR, gamma);
        f = fL + fR + (uR-uL);
        df = dfL + dfR;
        pNew = pStar - f/df;
        if pNew < 0
            pNew = 0.5*pStar;   % keep the iterate positive
        end
        if abs(pNew-pStar)/(pNew+pStar+1e-12) < 1e-10
            pStar = pNew;
            converged = true;
            break;
        end
        pStar = pNew;
    end
    if ~converged
        warning('exactRiemannSolution:noConvergence', ...
            'Newton iteration for p* did not converge; the exact solution may be inaccurate.');
    end

    [fL,~] = pressureFunction(pStar, rhoL, pL, cL, gamma);
    [fR,~] = pressureFunction(pStar, rhoR, pR, cR, gamma);
    uStar = 0.5*(uL + uR + fR - fL);

    S = (x - x0)/t;
    rho = zeros(size(x)); u = rho; p = rho;

    for i = 1:numel(x)
        s = S(i);
        if s <= uStar
            % left side of contact
            if pStar > pL
                % left shock
                SL = uL - cL*sqrt((gamma+1)/(2*gamma)*pStar/pL + (gamma-1)/(2*gamma));
                if s <= SL
                    rho(i)=rhoL; u(i)=uL; p(i)=pL;
                else
                    rhoStarL = rhoL*((pStar/pL + (gamma-1)/(gamma+1))/((gamma-1)/(gamma+1)*pStar/pL + 1));
                    rho(i)=rhoStarL; u(i)=uStar; p(i)=pStar;
                end
            else
                % left rarefaction
                cStarL = cL*(pStar/pL)^((gamma-1)/(2*gamma));
                SHL = uL - cL;
                STL = uStar - cStarL;
                if s <= SHL
                    rho(i)=rhoL; u(i)=uL; p(i)=pL;
                elseif s > STL
                    rhoStarL = rhoL*(pStar/pL)^(1/gamma);
                    rho(i)=rhoStarL; u(i)=uStar; p(i)=pStar;
                else
                    uFan = 2/(gamma+1)*(cL + 0.5*(gamma-1)*uL + s);
                    cFan = 2/(gamma+1)*(cL + 0.5*(gamma-1)*(uL - s));
                    rho(i)=rhoL*(cFan/cL)^(2/(gamma-1));
                    u(i)=uFan;
                    p(i)=pL*(cFan/cL)^(2*gamma/(gamma-1));
                end
            end
        else
            % right side of contact
            if pStar > pR
                % right shock
                SR = uR + cR*sqrt((gamma+1)/(2*gamma)*pStar/pR + (gamma-1)/(2*gamma));
                if s >= SR
                    rho(i)=rhoR; u(i)=uR; p(i)=pR;
                else
                    rhoStarR = rhoR*((pStar/pR + (gamma-1)/(gamma+1))/((gamma-1)/(gamma+1)*pStar/pR + 1));
                    rho(i)=rhoStarR; u(i)=uStar; p(i)=pStar;
                end
            else
                % right rarefaction
                cStarR = cR*(pStar/pR)^((gamma-1)/(2*gamma));
                SHR = uR + cR;
                STR = uStar + cStarR;
                if s >= SHR
                    rho(i)=rhoR; u(i)=uR; p(i)=pR;
                elseif s <= STR
                    rhoStarR = rhoR*(pStar/pR)^(1/gamma);
                    rho(i)=rhoStarR; u(i)=uStar; p(i)=pStar;
                else
                    uFan = 2/(gamma+1)*(-cR + 0.5*(gamma-1)*uR + s);
                    cFan = 2/(gamma+1)*(cR - 0.5*(gamma-1)*(uR - s));
                    rho(i)=rhoR*(cFan/cR)^(2/(gamma-1));
                    u(i)=uFan;
                    p(i)=pR*(cFan/cR)^(2*gamma/(gamma-1));
                end
            end
        end
    end
end

function [f,df] = pressureFunction(p, rhoK, pK, cK, gamma)
    % PRESSUREFUNCTION  Pressure function f_K(p) and its derivative for one
    % side K of the Riemann problem.
    %
    %   p            : trial star-region pressure
    %   rhoK, pK, cK : density, pressure and sound speed of the side-K state
    %   gamma        : ratio of specific heats
    %   f, df        : f_K(p) and df_K/dp
    %
    % p > pK uses the shock branch; otherwise the rarefaction branch is used.
    isShock = p > pK;
    if ~isShock
        ratio = p/pK;
        f  = 2*cK/(gamma-1)*(ratio^((gamma-1)/(2*gamma)) - 1);
        df = (1/(rhoK*cK))*ratio^(-(gamma+1)/(2*gamma));
        return;
    end

    AK = 2/((gamma+1)*rhoK);
    BK = (gamma-1)/(gamma+1)*pK;
    root = sqrt(AK/(p+BK));
    f  = (p-pK)*root;
    df = root*(1 - 0.5*(p-pK)/(p+BK));
end
