%% shock_tube_fd_exact_demo.m
% 一维 Euler 激波管：有限差分/守恒差分动态解 + Riemann 精确解对比
% -------------------------------------------------------------------------
% 目的：
% 1. 用动态动画展示激波管中膨胀波、接触间断、激波的传播；
% 2. 用精确 Riemann 解作为基准，帮助理解数值耗散与间断捕捉；
% 3. 在动态图中保留初始条件作为参照，便于比较波系如何从间断发展；
% 4. 在 x-t 平面展示三族波的大致轨迹：
%       lambda_1 = u-a,  lambda_2 = u,  lambda_3 = u+a
%
% 说明：
% - 数值格式默认使用 Lax-Friedrichs / Rusanov 型守恒差分通量；
% - 可切换为 MacCormack 格式，用于展示激波附近非物理振荡；
% - 精确解采用理想气体 Euler 方程 Riemann 问题解析结构，并通过 Newton 迭代求 p_star。
%
% 建议：
% 先运行 method='LF'，学生能看到稳定但耗散较大的激波和接触间断；
% 再运行 method='MacCormack'，学生能看到高阶中心型格式在激波附近可能振荡；
% 最后打开 showXT=true，解释三条波线和三族特征速度。

clear; clc; close all;

%% ======================= User-tunable parameters =======================
% --- problem selection ---
caseName = 'Sod';   % 'Sod' | 'Lax' | 'StrongShock' | 'DoubleRarefaction'
method   = 'LF';    % 'LF' | 'Rusanov' | 'MacCormack'
gamma    = 1.4;     % ratio of specific heats (ideal gas)

% --- spatial domain ---
xMin = 0.0;         % left end of the tube
xMax = 1.0;         % right end of the tube
x0   = 0.5;         % position of the initial discontinuity
N    = 500;         % number of grid points

% --- time stepping ---
CFL  = 0.55;        % Courant number used to size every time step
tEnd = 0.20;        % final time; leave empty ([]) to take the per-case default

% --- display ---
plotEvery            = 6;     % redraw the animation every plotEvery-th step
showXT               = true;  % draw the x-t wave diagram after the run
showInitialReference = true;  % keep the initial condition visible as a reference

% Whether the y-axes rescale automatically; false (fixed limits computed up
% front) is better suited to watching the animation in class.
autoY = false;

%% ======================= Initial conditions =======================
% Each state is a 1x3 row [rho, u, p]; the third output is the end time
% suggested for the selected case.
[leftState,rightState,tEndDefault] = getShockTubeCase(caseName);
if isempty(tEnd)
    % Only reached when the user left tEnd empty in the parameter section.
    tEnd = tEndDefault;
end

rhoL = leftState(1);
uL   = leftState(2);
pL   = leftState(3);
rhoR = rightState(1);
uR   = rightState(2);
pR   = rightState(3);

% Uniform grid stored as a column vector; both end points are grid points.
x  = linspace(xMin,xMax,N)';
dx = x(2)-x(1);

% Piecewise-constant Riemann data: left state for x < x0, right state otherwise.
onLeft  = (x <  x0);
onRight = (x >= x0);
rho = rhoL*onLeft + rhoR*onRight;
u   = uL  *onLeft + uR  *onRight;
p   = pL  *onLeft + pR  *onRight;

% Conservative variables [rho, rho*u, rho*E], one grid point per row.
U = prim2cons(rho,u,p,gamma);

% Frozen copy of the initial profiles, used as the fixed reference curves.
rho0 = rho;
u0   = u;
p0   = p;
a0   = sqrt(gamma*p0./rho0);   % initial speed of sound
M0   = u0./a0;                 % initial Mach number

%% ======================= Pre-compute the exact wave pattern =======================
% Star-region pressure/velocity and the type of the left and right waves.
exactInfo = exactRiemannInfo(leftState,rightState,gamma);

% Console summary of the selected problem and its exact star state.
fprintf('\nCase: %s\n',caseName);
fprintf('Method: %s\n',method);
fprintf('p_star = %.8f, u_star = %.8f\n',exactInfo.pStar,exactInfo.uStar);
fprintf('Left wave : %s\n',exactInfo.leftWave);
fprintf('Right wave: %s\n\n',exactInfo.rightWave);

%% ======================= Figure initialisation =======================
% 2x2 tiled figure: density, velocity, pressure and Mach number. Every panel
% holds three lines: the initial condition (dotted), the numerical solution
% (solid) and the exact solution (dashed). The numerical and exact lines are
% seeded with the initial data and get their YData overwritten while the
% solution is advanced in time.
fig = figure('Color','w','Position',[80 80 1250 760]);
tl = tiledlayout(fig,2,2,'TileSpacing','compact','Padding','compact');

% The reference lines are always created (so the colour order of the other
% two lines does not depend on the option) and merely hidden when not wanted.
if showInitialReference
    initVisible = 'on';
else
    initVisible = 'off';
end

% --- density ---
ax1 = nexttile(tl); hold(ax1,'on'); box(ax1,'on');
hRhoInit = plot(ax1,x,rho0,':','LineWidth',1.2,'Visible',initVisible);
hRhoNum  = plot(ax1,x,rho,'LineWidth',1.6);
hRhoEx   = plot(ax1,x,rho,'--','LineWidth',1.4);
ylabel(ax1,'\rho'); title(ax1,'Density');
% The legend is given the line handles explicitly. With a bare label list the
% labels are paired with the plotted objects in creation order, so the hidden
% reference line would take 'numerical' and the numerical line would be
% labelled 'exact' whenever showInitialReference is false.
if showInitialReference
    legend(ax1,[hRhoInit,hRhoNum,hRhoEx],{'initial','numerical','exact'}, ...
        'Location','best');
else
    legend(ax1,[hRhoNum,hRhoEx],{'numerical','exact'},'Location','best');
end

% --- velocity ---
ax2 = nexttile(tl); hold(ax2,'on'); box(ax2,'on');
hUInit = plot(ax2,x,u0,':','LineWidth',1.2,'Visible',initVisible);
hUNum  = plot(ax2,x,u,'LineWidth',1.6);
hUEx   = plot(ax2,x,u,'--','LineWidth',1.4);
ylabel(ax2,'u'); title(ax2,'Velocity');

% --- pressure ---
ax3 = nexttile(tl); hold(ax3,'on'); box(ax3,'on');
hPInit = plot(ax3,x,p0,':','LineWidth',1.2,'Visible',initVisible);
hPNum  = plot(ax3,x,p,'LineWidth',1.6);
hPEx   = plot(ax3,x,p,'--','LineWidth',1.4);
ylabel(ax3,'p'); xlabel(ax3,'x'); title(ax3,'Pressure');

% --- Mach number ---
ax4 = nexttile(tl); hold(ax4,'on'); box(ax4,'on');
a = sqrt(gamma*p./rho);
M = u./a;
hMInit = plot(ax4,x,M0,':','LineWidth',1.2,'Visible',initVisible);
hMNum  = plot(ax4,x,M,'LineWidth',1.6);
hMEx   = plot(ax4,x,M,'--','LineWidth',1.4);
ylabel(ax4,'M'); xlabel(ax4,'x'); title(ax4,'Mach number');

title(tl,sprintf('1D Euler shock tube: %s, %s, t = %.4f',caseName,method,0), ...
    'FontWeight','bold');

% Fixed y-limits: cover the exact solution at the final time together with
% the initial data, padded by expandLimits, so the axes do not jump around
% during the animation.
if ~autoY
    [rhoE,uE,pE,ME] = exactSolution(x,tEnd,x0,leftState,rightState,gamma);
    set(ax1,'YLim',expandLimits([rhoE;rho;rho0]));
    set(ax2,'YLim',expandLimits([uE;u;u0]));
    set(ax3,'YLim',expandLimits([pE;p;p0]));
    set(ax4,'YLim',expandLimits([ME;M;M0]));
end

%% ======================= Time marching with live animation =======================
t = 0.0;
step = 0;

while t < tEnd
    % Time step from the CFL condition, based on the largest |u|+a on the grid.
    [rho,u,p] = cons2prim(U,gamma);
    a = sqrt(gamma*p./rho);
    dt = CFL*dx/max(abs(u)+a);
    if t+dt > tEnd
        dt = tEnd-t;   % shorten the last step so the run ends at tEnd
    end

    % Advance the interior points with the selected scheme. 'LF' and
    % 'Rusanov' share the same stepper; the name is matched case-insensitively.
    switch lower(method)
        case 'maccormack'
            U = stepMacCormack(U,gamma,dx,dt);
        case {'lf','rusanov'}
            U = stepRusanov(U,gamma,dx,dt);
        otherwise
            error('Unknown method: %s',method);
    end

    % Simple transmissive boundaries: copy the neighbouring interior point.
    U(1,:)   = U(2,:);
    U(end,:) = U(end-1,:);

    step = step+1;
    t = t+dt;

    % Refresh the plots every plotEvery steps and always on the final step.
    if mod(step,plotEvery)==0 || t>=tEnd
        [rho,u,p] = cons2prim(U,gamma);
        a = sqrt(gamma*p./rho);
        M = u./a;

        [rhoEx,uEx,pEx,MEx] = exactSolution(x,t,x0,leftState,rightState,gamma);

        % Numerical curves
        hRhoNum.YData = rho;
        hUNum.YData   = u;
        hPNum.YData   = p;
        hMNum.YData   = M;

        % Exact curves at the same time level
        hRhoEx.YData = rhoEx;
        hUEx.YData   = uEx;
        hPEx.YData   = pEx;
        hMEx.YData   = MEx;

        headline = sprintf('1D Euler shock tube: %s, %s, t = %.4f', ...
            caseName,method,t);
        title(tl,headline,'FontWeight','bold');

        drawnow;
    end
end

%% ======================= x-t wave diagram =======================
if showXT
    plotXTDiagram(xMin,xMax,x0,tEnd,leftState,rightState,gamma,exactInfo);
end

%% ========================================================================
%                              局部函数
% ========================================================================

function [leftState,rightState,tEnd] = getShockTubeCase(caseName)
%GETSHOCKTUBECASE Initial data for a named shock-tube test case.
%   [leftState,rightState,tEnd] = GETSHOCKTUBECASE(caseName) returns the
%   left and right states as 1x3 rows [rho, u, p] together with the end time
%   associated with the case. caseName is matched case-insensitively against
%   'Sod', 'Lax', 'StrongShock' and 'DoubleRarefaction'.
%
%   An error naming the offending value is raised for any other caseName.
    switch lower(caseName)
        case 'sod'
            leftState  = [1.0,   0.0, 1.0];
            rightState = [0.125, 0.0, 0.1];
            tEnd = 0.20;
        case 'lax'
            leftState  = [0.445, 0.698, 3.528];
            rightState = [0.500, 0.000, 0.571];
            tEnd = 0.16;
        case 'strongshock'
            leftState  = [1.0, 0.0, 1000.0];
            rightState = [1.0, 0.0, 0.01];
            tEnd = 0.012;
        case 'doublerarefaction'
            leftState  = [1.0, -2.0, 0.4];
            rightState = [1.0,  2.0, 0.4];
            tEnd = 0.15;
        otherwise
            % Report the rejected name and the accepted ones, in the same
            % style as the script's 'Unknown method: %s' error.
            error(['Unknown caseName: %s (expected ''Sod'', ''Lax'', ' ...
                '''StrongShock'' or ''DoubleRarefaction'').'],caseName);
    end
end

function U = prim2cons(rho,u,p,gamma)
%PRIM2CONS Primitive (rho, u, p) to conservative variables.
%   U = PRIM2CONS(rho,u,p,gamma) returns [rho, rho.*u, rho.*E], the three
%   inputs concatenated horizontally, where E = p./((gamma-1)*rho) + 0.5*u.^2
%   is the total energy per unit mass. The script passes column vectors, so
%   U has one grid point per row.
    internalPart = p./((gamma-1)*rho);   % internal energy per unit mass
    kineticPart  = 0.5*u.^2;             % kinetic energy per unit mass
    totalE       = internalPart + kineticPart;

    momentum = rho.*u;
    energy   = rho.*totalE;
    U = [rho, momentum, energy];
end

function [rho,u,p] = cons2prim(U,gamma)
%CONS2PRIM Conservative variables to primitive (rho, u, p).
%   [rho,u,p] = CONS2PRIM(U,gamma) reads the columns of U as
%   [rho, rho*u, rho*E] and returns density, velocity and pressure.
%
%   The returned density and pressure are floored at 1e-12. The floors are
%   applied after u and p have been evaluated, so the velocity is derived
%   from the raw (unfloored) density.
    mass     = U(:,1);
    momentum = U(:,2);
    energy   = U(:,3);

    u = momentum./mass;
    E = energy./mass;                       % total energy per unit mass
    pRaw = (gamma-1)*mass.*(E-0.5*u.^2);

    % Keep the demo running if MacCormack produces a non-positive density or
    % pressure near a strong shock.
    rho = max(mass,1e-12);
    p   = max(pRaw,1e-12);
end

function F = fluxEuler(U,gamma)
%FLUXEULER Physical flux of the 1-D Euler equations.
%   F = FLUXEULER(U,gamma) returns, row by row,
%   [rho*u, rho*u^2 + p, u*(rho*E + p)], using the primitive variables
%   delivered by cons2prim (which floors density and pressure).
    [dens,vel,pres] = cons2prim(U,gamma);
    totalE = U(:,3)./dens;                    % total energy per unit mass

    massFlux     = dens.*vel;
    momentumFlux = dens.*vel.^2+pres;
    energyFlux   = vel.*(dens.*totalE+pres);

    F = [massFlux, momentumFlux, energyFlux];
end

function Unew = stepRusanov(U,gamma,dx,dt)
%STEPRUSANOV One conservative time step with the Rusanov numerical flux.
%   Unew = STEPRUSANOV(U,gamma,dx,dt) advances rows 2..N-1 of U by dt. The
%   flux at the interface between rows i and i+1 is
%       0.5*(F_i + F_{i+1}) - 0.5*smax*(U_{i+1} - U_i),
%   where smax is the larger of |u|+a on the two sides. Rows 1 and N are
%   returned unchanged; the caller is responsible for boundary values.
    nPts = size(U,1);

    % Point-wise physical flux and signal speed |u| + a.
    physFlux = fluxEuler(U,gamma);
    [rho,u,p] = cons2prim(U,gamma);
    signalSpeed = abs(u) + sqrt(gamma*p./rho);

    % Interface k lies between rows k and k+1.
    iLeft  = 1:nPts-1;
    iRight = 2:nPts;
    smax = max(signalSpeed(iLeft),signalSpeed(iRight));
    numFlux = 0.5*(physFlux(iLeft,:)+physFlux(iRight,:)) ...
        - 0.5*smax.*(U(iRight,:)-U(iLeft,:));

    % Flux-difference update of the interior rows only.
    Unew = U;
    interior = 2:nPts-1;
    Unew(interior,:) = U(interior,:) ...
        - dt/dx*(numFlux(interior,:)-numFlux(interior-1,:));
end

function Unew = stepMacCormack(U,gamma,dx,dt,epsAV)
%STEPMACCORMACK One time step of the MacCormack predictor-corrector scheme.
%   Unew = STEPMACCORMACK(U,gamma,dx,dt) advances rows 2..N-1 of U by dt
%   using a forward-difference predictor and a backward-difference
%   corrector, then adds a second-difference artificial-viscosity term with
%   coefficient 0.02. Rows 1 and N of the result are copied from U.
%
%   Unew = STEPMACCORMACK(U,gamma,dx,dt,epsAV) uses the given
%   artificial-viscosity coefficient instead. Pass 0 to switch the term off
%   and observe the oscillations near discontinuities; an empty value falls
%   back to the default 0.02.
    if nargin < 5 || isempty(epsAV)
        epsAV = 0.02;   % small amount that keeps the classroom demo from blowing up
    end

    N = size(U,1);
    interior = 2:N-1;
    lambda = dt/dx;

    % Predictor: forward difference of the flux.
    F = fluxEuler(U,gamma);
    Up = U;
    Up(interior,:) = U(interior,:) - lambda*(F(interior+1,:)-F(interior,:));
    % End rows of the predicted state are copied from their neighbours so the
    % corrector's backward difference has a value at row 1.
    Up(1,:) = Up(2,:);
    Up(N,:) = Up(N-1,:);

    % Corrector: backward difference of the predicted flux, averaged with U.
    Fp = fluxEuler(Up,gamma);
    Unew = U;
    Unew(interior,:) = 0.5*(U(interior,:) + Up(interior,:) ...
        - lambda*(Fp(interior,:)-Fp(interior-1,:)));

    % Artificial viscosity built from the second difference of the OLD state.
    Unew(interior,:) = Unew(interior,:) ...
        + epsAV*(U(interior+1,:)-2*U(interior,:)+U(interior-1,:));
end

function info = exactRiemannInfo(leftState,rightState,gamma)
%EXACTRIEMANNINFO Star state and wave types of the exact Riemann solution.
%   info = EXACTRIEMANNINFO(leftState,rightState,gamma) takes the two states
%   as [rho, u, p] and returns a struct with fields
%       pStar, uStar        pressure and velocity in the star region
%       aL, aR              sound speeds of the left and right states
%       leftWave, rightWave 'shock' or 'rarefaction'
%
%   pStar is found by Newton iteration on fL(p) + fR(p) + uR - uL = 0, with
%   at most 80 iterations and a relative tolerance of 1e-10. The last
%   iterate is used even if the tolerance was not reached.
    rhoL = leftState(1);   uL = leftState(2);   pL = leftState(3);
    rhoR = rightState(1);  uR = rightState(2);  pR = rightState(3);

    aL = sqrt(gamma*pL/rhoL);
    aR = sqrt(gamma*pR/rhoR);

    % Starting guess from the linearised (primitive-variable) estimate,
    % kept strictly positive.
    pGuess = 0.5*(pL+pR) - 0.125*(uR-uL)*(rhoL+rhoR)*(aL+aR);
    pCur = max(1e-8,pGuess);

    for iter = 1:80
        [fL,dfL] = pressureFunction(pCur,rhoL,pL,aL,gamma);
        [fR,dfR] = pressureFunction(pCur,rhoR,pR,aR,gamma);

        % Newton update, floored so the pressure stays positive.
        pNext = max(pCur - (fL+fR+uR-uL)/(dfL+dfR),1e-10);

        relChange = abs(pNext-pCur)/(0.5*(pNext+pCur));
        if relChange < 1e-10
            break;
        end
        pCur = pNext;
    end
    pStar = pNext;

    % Star velocity from the two pressure functions evaluated at pStar.
    [fL,~] = pressureFunction(pStar,rhoL,pL,aL,gamma);
    [fR,~] = pressureFunction(pStar,rhoR,pR,aR,gamma);
    uStar = 0.5*(uL+uR+fR-fL);

    % A wave is a shock when the star pressure exceeds the pressure ahead of it.
    leftKind = 'rarefaction';
    if pStar > pL
        leftKind = 'shock';
    end
    rightKind = 'rarefaction';
    if pStar > pR
        rightKind = 'shock';
    end

    info = struct('pStar',pStar,'uStar',uStar,'aL',aL,'aR',aR, ...
        'leftWave',leftKind,'rightWave',rightKind);
end

function [f,df] = pressureFunction(p,rhoK,pK,aK,gamma)
%PRESSUREFUNCTION Pressure function of one side of the Riemann problem.
%   [f,df] = PRESSUREFUNCTION(p,rhoK,pK,aK,gamma) returns the value f and
%   the derivative df with respect to p, for a side with density rhoK,
%   pressure pK and sound speed aK. The shock form is used when p > pK and
%   the rarefaction form otherwise.
    if ~(p > pK)
        % Rarefaction branch.
        ratio = p/pK;
        f  = 2*aK/(gamma-1)*(ratio^((gamma-1)/(2*gamma))-1);
        df = 1/(rhoK*aK)*ratio^(-(gamma+1)/(2*gamma));
        return;
    end

    % Shock branch.
    AK = 2/((gamma+1)*rhoK);
    BK = (gamma-1)/(gamma+1)*pK;
    root = sqrt(AK/(p+BK));
    f  = (p-pK)*root;
    df = root*(1-0.5*(p-pK)/(p+BK));
end

function [rho,u,p,M] = exactSolution(x,t,x0,leftState,rightState,gamma)
%EXACTSOLUTION Sample the exact Riemann solution on the points x at time t.
%   [rho,u,p,M] = EXACTSOLUTION(x,t,x0,leftState,rightState,gamma) evaluates
%   the self-similar solution at xi = (x-x0)/t for an initial discontinuity
%   at x0 separating leftState and rightState (each [rho, u, p]). The
%   outputs have the size of x; M is the Mach number u/a.
%
%   t is clamped from below at 1e-14 to avoid a division by zero. The star
%   state is recomputed through exactRiemannInfo on every call.
    rhoL = leftState(1);   uL = leftState(2);   pL = leftState(3);
    rhoR = rightState(1);  uR = rightState(2);  pR = rightState(3);

    info  = exactRiemannInfo(leftState,rightState,gamma);
    pStar = info.pStar;
    uStar = info.uStar;
    aL    = info.aL;
    aR    = info.aR;

    % ---- left wave: everything that does not depend on the sample point ----
    leftIsShock = pStar > pL;
    if leftIsShock
        SL = uL - aL*sqrt((gamma+1)/(2*gamma)*pStar/pL + (gamma-1)/(2*gamma));
        rhoStarL = rhoL*((pStar/pL + (gamma-1)/(gamma+1)) / ...
            ((gamma-1)/(gamma+1)*pStar/pL + 1));
    else
        aStarL = aL*(pStar/pL)^((gamma-1)/(2*gamma));
        SHL = uL - aL;            % head of the left fan
        STL = uStar - aStarL;     % tail of the left fan
        rhoStarL = rhoL*(pStar/pL)^(1/gamma);
    end

    % ---- right wave ----
    rightIsShock = pStar > pR;
    if rightIsShock
        SR = uR + aR*sqrt((gamma+1)/(2*gamma)*pStar/pR + (gamma-1)/(2*gamma));
        rhoStarR = rhoR*((pStar/pR + (gamma-1)/(gamma+1)) / ...
            ((gamma-1)/(gamma+1)*pStar/pR + 1));
    else
        aStarR = aR*(pStar/pR)^((gamma-1)/(2*gamma));
        SHR = uR + aR;            % head of the right fan
        STR = uStar + aStarR;     % tail of the right fan
        rhoStarR = rhoR*(pStar/pR)^(1/gamma);
    end

    xi = (x-x0)./max(t,1e-14);

    rho = zeros(size(x));
    u   = zeros(size(x));
    p   = zeros(size(x));

    for k = 1:length(x)
        s = xi(k);

        if s <= uStar
            % Sample point lies to the left of the contact discontinuity.
            if leftIsShock
                if s <= SL
                    rho(k) = rhoL;      u(k) = uL;     p(k) = pL;
                else
                    rho(k) = rhoStarL;  u(k) = uStar;  p(k) = pStar;
                end
            elseif s <= SHL
                rho(k) = rhoL;      u(k) = uL;     p(k) = pL;
            elseif s > STL
                rho(k) = rhoStarL;  u(k) = uStar;  p(k) = pStar;
            else
                % Inside the left rarefaction fan.
                uFan = 2/(gamma+1)*(aL + 0.5*(gamma-1)*uL + s);
                aFan = 2/(gamma+1)*(aL + 0.5*(gamma-1)*(uL-s));
                rho(k) = rhoL*(aFan/aL)^(2/(gamma-1));
                u(k)   = uFan;
                p(k)   = pL*(aFan/aL)^(2*gamma/(gamma-1));
            end
        else
            % Sample point lies to the right of the contact discontinuity.
            if rightIsShock
                if s >= SR
                    rho(k) = rhoR;      u(k) = uR;     p(k) = pR;
                else
                    rho(k) = rhoStarR;  u(k) = uStar;  p(k) = pStar;
                end
            elseif s >= SHR
                rho(k) = rhoR;      u(k) = uR;     p(k) = pR;
            elseif s <= STR
                rho(k) = rhoStarR;  u(k) = uStar;  p(k) = pStar;
            else
                % Inside the right rarefaction fan.
                uFan = 2/(gamma+1)*(-aR + 0.5*(gamma-1)*uR + s);
                aFan = 2/(gamma+1)*(aR - 0.5*(gamma-1)*(uR-s));
                rho(k) = rhoR*(aFan/aR)^(2/(gamma-1));
                u(k)   = uFan;
                p(k)   = pR*(aFan/aR)^(2*gamma/(gamma-1));
            end
        end
    end

    M = u./sqrt(gamma*p./rho);
end

function lim = expandLimits(y)
%EXPANDLIMITS Axis limits enclosing the data y with a margin on both sides.
%   lim = EXPANDLIMITS(y) returns [min(y)-pad, max(y)+pad]. pad is 10% of
%   the data range, or 0.1*max(1,|max(y)|) when the range is below 1e-12
%   (essentially constant data), so the limits never collapse to a point.
    lo = min(y);
    hi = max(y);
    span = hi-lo;

    if abs(span) < 1e-12
        margin = 0.1*max(1,abs(hi));
    else
        margin = 0.10*span;
    end

    lim = [lo, hi] + [-margin, margin];
end

function plotXTDiagram(xMin,xMax,x0,tEnd,leftState,rightState,gamma,info)
%PLOTXTDIAGRAM Draw the wave pattern of the Riemann problem in the x-t plane.
%   PLOTXTDIAGRAM(xMin,xMax,x0,tEnd,leftState,rightState,gamma,info) opens a
%   new figure and draws, as straight lines through (x0, 0):
%     - the contact discontinuity, moving with info.uStar;
%     - for each side either the shock, or the head and tail of the
%       rarefaction fan plus nine dotted lines spanning the fan.
%   info.leftWave / info.rightWave ('shock' or anything else) select which
%   of the two is drawn. The axes are limited to [xMin,xMax] x [0,tEnd].
    rhoL = leftState(1);   uL = leftState(2);   pL = leftState(3);
    rhoR = rightState(1);  uR = rightState(2);  pR = rightState(3);

    aL = sqrt(gamma*pL/rhoL);
    aR = sqrt(gamma*pR/rhoR);
    pStar = info.pStar;
    uStar = info.uStar;

    figure('Color','w','Position',[120 120 900 620]);
    ax = gca;
    hold(ax,'on'); box(ax,'on'); grid(ax,'on');
    xlabel(ax,'x'); ylabel(ax,'t');
    title(ax,'x-t diagram of wave structure and characteristic families');

    tt = linspace(0,tEnd,200);

    % Contact discontinuity
    plot(ax,x0 + uStar*tt,tt,'LineWidth',2.2,'DisplayName','contact: dx/dt=u_*');

    % Left-going wave
    leftIsShock = strcmp(info.leftWave,'shock');
    if leftIsShock
        SL = uL - aL*sqrt((gamma+1)/(2*gamma)*pStar/pL + (gamma-1)/(2*gamma));
        plot(ax,x0 + SL*tt,tt,'LineWidth',2.2,'DisplayName','left shock');
    else
        aStarL = aL*(pStar/pL)^((gamma-1)/(2*gamma));
        SHL = uL - aL;
        STL = uStar - aStarL;
        plot(ax,x0 + SHL*tt,tt,'LineWidth',2.0,'DisplayName','left rarefaction head');
        plot(ax,x0 + STL*tt,tt,'LineWidth',2.0,'DisplayName','left rarefaction tail');

        % Dotted lines inside the fan, kept out of the legend.
        fanSpeeds = linspace(SHL,STL,9);
        for k = 1:numel(fanSpeeds)
            plot(ax,x0 + fanSpeeds(k)*tt,tt,':','LineWidth',1.0, ...
                'HandleVisibility','off');
        end
    end

    % Right-going wave
    rightIsShock = strcmp(info.rightWave,'shock');
    if rightIsShock
        SR = uR + aR*sqrt((gamma+1)/(2*gamma)*pStar/pR + (gamma-1)/(2*gamma));
        plot(ax,x0 + SR*tt,tt,'LineWidth',2.2,'DisplayName','right shock');
    else
        aStarR = aR*(pStar/pR)^((gamma-1)/(2*gamma));
        SHR = uR + aR;
        STR = uStar + aStarR;
        plot(ax,x0 + STR*tt,tt,'LineWidth',2.0,'DisplayName','right rarefaction tail');
        plot(ax,x0 + SHR*tt,tt,'LineWidth',2.0,'DisplayName','right rarefaction head');

        fanSpeeds = linspace(STR,SHR,9);
        for k = 1:numel(fanSpeeds)
            plot(ax,x0 + fanSpeeds(k)*tt,tt,':','LineWidth',1.0, ...
                'HandleVisibility','off');
        end
    end

    xlim(ax,[xMin,xMax]);
    ylim(ax,[0,tEnd]);
    legend(ax,'Location','bestoutside');

    % Mark the initial discontinuity and list the three characteristic speeds.
    text(ax,x0,0,' initial discontinuity','VerticalAlignment','bottom');
    annotationText = sprintf('\\lambda_1=u-a,   \\lambda_2=u,   \\lambda_3=u+a');
    text(ax,xMin+0.03*(xMax-xMin),0.92*tEnd,annotationText,'FontSize',12);
end
