%% shock_tube_fd_exact_demo_xy_contour_combined.m
% 一维 Euler 激波管：线图动画 + 当前时刻 xy 伪二维场动画
% -------------------------------------------------------------------------
%
% 教学目的：
% 1. 在线图中保留初始条件、数值解和精确解；
% 2. 在同一个 figure 中同步显示当前时刻的二维 x-y 场；
% 3. 时间 t 不作为纵坐标，而是通过动画推进体现；
% 4. 二维场是将一维解沿 y 方向复制，用于课堂可视化。
%
% 说明：
% - 激波管本质是一维问题，因此 x-y 图中的 y 方向没有真实物理变化；
% - 这样处理的目的，是让学生直观看到间断面、膨胀区和平台区随时间移动；
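% - 数值格式默认使用 LF/Rusanov（局部 Lax-Friedrichs）型守恒通量，'LF' 与 'Rusanov' 两个选项调用同一格式；
% - 可切换 MacCormack 观察激波附近的数值振荡；该格式在 stepMacCormack 中附加了少量人工粘性 epsAV（默认 0.02），设为 0 可看到更明显的振荡。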

clear; clc; close all;

%% ======================= User-adjustable parameters =======================
caseName = 'Sod';          % 'Sod' | 'Lax' | 'StrongShock' | 'DoubleRarefaction'
method   = 'LF';           % 'LF' | 'Rusanov' | 'MacCormack'
                           % ('LF' and 'Rusanov' select the same local Lax-Friedrichs flux)
gamma    = 1.4;

N        = 500;            % number of grid points in x
Ny       = 70;             % number of display points in y (visualisation only)
xMin     = 0.0;
xMax     = 1.0;
yMin     = 0.0;
yMax     = 0.20;
x0       = 0.5;            % position of the initial discontinuity
CFL      = 0.55;
tEnd     = [];             % final time; [] = use the default time of the chosen case
                           % (as returned by getShockTubeCase), or set a number to override
plotEvery = 6;             % redraw every plotEvery time steps
showXT   = true;           % also draw a separate x-t wave-structure diagram
showInitialReference = true; % keep the initial condition as a reference line in the line plots
showExactLine = true;      % overlay the exact solution in the line plots

% Whether the line-plot y-axes rescale automatically; false makes the motion
% easier to follow in class.
% Note: this only controls the line-plot y-limits, not the colour limits of the 2D fields.
autoY = false;

%% ======================= Initial condition =======================
[leftState,rightState,tEndDefault] = getShockTubeCase(caseName);
if isempty(tEnd)
    % An empty tEnd means "use the recommended final time of this case".
    tEnd = tEndDefault;
end

rhoL = leftState(1); uL = leftState(2); pL = leftState(3);
rhoR = rightState(1); uR = rightState(2); pR = rightState(3);

x  = linspace(xMin,xMax,N)';
dx = x(2)-x(1);
y  = linspace(yMin,yMax,Ny)';   % y exists only for the pseudo-2D display

% Piecewise-constant Riemann data with the jump at x0 (x0 itself takes the right state).
rho = rhoL*(x<x0) + rhoR*(x>=x0);
u   = uL  *(x<x0) + uR  *(x>=x0);
p   = pL  *(x<x0) + pR  *(x>=x0);

U = prim2cons(rho,u,p,gamma);

% Keep the initial condition as a fixed reference in the animation.
rho0 = rho;
u0   = u;
p0   = p;
M0   = u0./sqrt(gamma*p0./rho0);

%% ======================= Pre-compute the exact wave structure =======================
exactInfo = exactRiemannInfo(leftState,rightState,gamma);

fprintf('\nCase: %s\n',caseName);
fprintf('Method: %s\n',method);
fprintf('p_star = %.8f, u_star = %.8f\n',exactInfo.pStar,exactInfo.uStar);
fprintf('Left wave : %s\n',exactInfo.leftWave);
fprintf('Right wave: %s\n\n',exactInfo.rightWave);

% Fix the line-plot y-limits and the 2D colour limits from the union of the
% initial data and the exact solution at the final time.
[rhoEnd,uEnd,pEnd,MEnd] = exactSolution(x,tEnd,x0,leftState,rightState,gamma);

rhoLim = expandLimits([rhoEnd;rho0]);
uLim   = expandLimits([uEnd;u0]);
pLim   = expandLimits([pEnd;p0]);
MLim   = expandLimits([MEnd;M0]);

%% ======================= Figure setup: line plots + current xy fields =======================
fig = figure('Color','w','Position',[60 40 1550 980]);
tl = tiledlayout(fig,4,2,'TileSpacing','compact','Padding','compact');
% Odd tiles (left column): line plots of initial / numerical / exact solution.
% Even tiles (right column): the same variable replicated along y as an image.

% Legend entries are passed together with their line handles: a line hidden
% with Visible='off' still counts as a legend candidate, so passing labels
% alone would attach them to the wrong curves whenever a line is hidden.
legendMask = [logical(showInitialReference), true, logical(showExactLine)];

% ---------- rho ----------
axRhoLine = nexttile(tl,1); hold(axRhoLine,'on'); box(axRhoLine,'on');
hRhoInit = plot(axRhoLine,x,rho0,':','LineWidth',1.2);
hRhoNum  = plot(axRhoLine,x,rho,'LineWidth',1.7);
hRhoEx   = plot(axRhoLine,x,rho,'--','LineWidth',1.3);   % placeholder; filled with the exact solution in the time loop
ylabel(axRhoLine,'\rho'); title(axRhoLine,'Density: line view');
if ~showInitialReference, set(hRhoInit,'Visible','off'); end
if ~showExactLine, set(hRhoEx,'Visible','off'); end
hRhoLines = [hRhoInit, hRhoNum, hRhoEx];
legend(axRhoLine,hRhoLines(legendMask),legendItems(showInitialReference,showExactLine),'Location','best');
if ~autoY, ylim(axRhoLine,rhoLim); end

axRho2D = nexttile(tl,2); hold(axRho2D,'on'); box(axRho2D,'on');
hRho2D = imagesc(axRho2D,x,y,field2D(rho,Ny));
set(axRho2D,'YDir','normal');
hRhoX0 = plot(axRho2D,[x0 x0],[yMin yMax],':','LineWidth',1.0);   % initial discontinuity marker
colorbar(axRho2D); climCompat(axRho2D,rhoLim);
ylabel(axRho2D,'y'); title(axRho2D,'Density: current xy field');

% ---------- u ----------
axULine = nexttile(tl,3); hold(axULine,'on'); box(axULine,'on');
hUInit = plot(axULine,x,u0,':','LineWidth',1.2);
hUNum  = plot(axULine,x,u,'LineWidth',1.7);
hUEx   = plot(axULine,x,u,'--','LineWidth',1.3);
ylabel(axULine,'u'); title(axULine,'Velocity: line view');
if ~showInitialReference, set(hUInit,'Visible','off'); end
if ~showExactLine, set(hUEx,'Visible','off'); end
if ~autoY, ylim(axULine,uLim); end

axU2D = nexttile(tl,4); hold(axU2D,'on'); box(axU2D,'on');
hU2D = imagesc(axU2D,x,y,field2D(u,Ny));
set(axU2D,'YDir','normal');
hUX0 = plot(axU2D,[x0 x0],[yMin yMax],':','LineWidth',1.0);
colorbar(axU2D); climCompat(axU2D,uLim);
ylabel(axU2D,'y'); title(axU2D,'Velocity: current xy field');

% ---------- p ----------
axPLine = nexttile(tl,5); hold(axPLine,'on'); box(axPLine,'on');
hPInit = plot(axPLine,x,p0,':','LineWidth',1.2);
hPNum  = plot(axPLine,x,p,'LineWidth',1.7);
hPEx   = plot(axPLine,x,p,'--','LineWidth',1.3);
ylabel(axPLine,'p'); title(axPLine,'Pressure: line view');
if ~showInitialReference, set(hPInit,'Visible','off'); end
if ~showExactLine, set(hPEx,'Visible','off'); end
if ~autoY, ylim(axPLine,pLim); end

axP2D = nexttile(tl,6); hold(axP2D,'on'); box(axP2D,'on');
hP2D = imagesc(axP2D,x,y,field2D(p,Ny));
set(axP2D,'YDir','normal');
hPX0 = plot(axP2D,[x0 x0],[yMin yMax],':','LineWidth',1.0);
colorbar(axP2D); climCompat(axP2D,pLim);
ylabel(axP2D,'y'); title(axP2D,'Pressure: current xy field');

% ---------- M ----------
a = sqrt(gamma*p./rho);
M = u./a;
axMLine = nexttile(tl,7); hold(axMLine,'on'); box(axMLine,'on');
hMInit = plot(axMLine,x,M0,':','LineWidth',1.2);
hMNum  = plot(axMLine,x,M,'LineWidth',1.7);
hMEx   = plot(axMLine,x,M,'--','LineWidth',1.3);
ylabel(axMLine,'M'); xlabel(axMLine,'x'); title(axMLine,'Mach number: line view');
if ~showInitialReference, set(hMInit,'Visible','off'); end
if ~showExactLine, set(hMEx,'Visible','off'); end
if ~autoY, ylim(axMLine,MLim); end

axM2D = nexttile(tl,8); hold(axM2D,'on'); box(axM2D,'on');
hM2D = imagesc(axM2D,x,y,field2D(M,Ny));
set(axM2D,'YDir','normal');
hMX0 = plot(axM2D,[x0 x0],[yMin yMax],':','LineWidth',1.0);
colorbar(axM2D); climCompat(axM2D,MLim);
ylabel(axM2D,'y'); xlabel(axM2D,'x'); title(axM2D,'Mach number: current xy field');

% Common x/y ranges for all panels.
allAxes2D = [axRho2D,axU2D,axP2D,axM2D];
for ax = allAxes2D
    xlim(ax,[xMin xMax]); ylim(ax,[yMin yMax]);
end
allLineAxes = [axRhoLine,axULine,axPLine,axMLine];
for ax = allLineAxes
    xlim(ax,[xMin xMax]); grid(ax,'on');
end

sgtitleText = sprintf('1D Euler shock tube: %s, %s, t = %.4f',caseName,method,0);
title(tl,sgtitleText,'FontWeight','bold');

%% ======================= Time marching with live animation =======================
t = 0.0;
step = 0;

while t < tEnd
    % Stop cleanly if the animation window has been closed; the set(...)
    % calls below would otherwise fail on deleted graphics handles.
    if ~isgraphics(fig)
        fprintf('Figure closed at t = %.4f; time loop stopped.\n',t);
        break;
    end

    % max() ignores NaN, so a blown-up cell would otherwise go unnoticed in
    % the CFL estimate below; fail loudly instead.
    if ~all(isfinite(U(:)))
        error('Non-finite state at t = %.6g (step %d): the numerical solution has blown up.',t,step);
    end

    % CFL-limited time step from the largest characteristic speed |u|+a.
    [rho,u,p] = cons2prim(U,gamma);
    a = sqrt(gamma*p./rho);
    dt = CFL*dx/max(abs(u)+a);
    if t+dt > tEnd
        dt = tEnd-t;   % land exactly on tEnd
    end

    switch lower(method)
        case {'lf','rusanov'}
            U = stepRusanov(U,gamma,dx,dt);
        case 'maccormack'
            U = stepMacCormack(U,gamma,dx,dt);
        otherwise
            error('Unknown method: %s',method);
    end

    % Simple transmissive (zero-gradient) boundaries.
    U(1,:)   = U(2,:);
    U(end,:) = U(end-1,:);

    t = t+dt;
    step = step+1;

    if mod(step,plotEvery)==0 || t>=tEnd
        [rho,u,p] = cons2prim(U,gamma);
        a = sqrt(gamma*p./rho);
        M = u./a;

        [rhoEx,uEx,pEx,MEx] = exactSolution(x,t,x0,leftState,rightState,gamma);

        % Update the line plots.
        set(hRhoNum,'YData',rho); set(hRhoEx,'YData',rhoEx);
        set(hUNum,  'YData',u);   set(hUEx,  'YData',uEx);
        set(hPNum,  'YData',p);   set(hPEx,  'YData',pEx);
        set(hMNum,  'YData',M);   set(hMEx,  'YData',MEx);

        % Update the current xy fields: only CData changes, no graphics
        % objects are deleted or recreated. This avoids the problem of
        % contourf producing no object for (locally) constant fields.
        set(hRho2D,'CData',field2D(rho,Ny));
        set(hU2D,  'CData',field2D(u,Ny));
        set(hP2D,  'CData',field2D(p,Ny));
        set(hM2D,  'CData',field2D(M,Ny));

        % Keep the initial-discontinuity marker lines on top of the images.
        uistack(hRhoX0,'top'); uistack(hUX0,'top');
        uistack(hPX0,'top');   uistack(hMX0,'top');

        % Re-apply colour limits and ranges so the display never drifts.
        climCompat(axRho2D,rhoLim); climCompat(axU2D,uLim);
        climCompat(axP2D,pLim);     climCompat(axM2D,MLim);
        for ax = allAxes2D
            xlim(ax,[xMin xMax]); ylim(ax,[yMin yMax]);
            set(ax,'YDir','normal');
        end

        title(tl,sprintf('1D Euler shock tube: %s, %s, t = %.4f',caseName,method,t), ...
            'FontWeight','bold');
        drawnow;
    end
end

%% ======================= x-t wave diagram =======================
% Second figure with the exact wave pattern (contact, shocks, rarefaction
% fans) in the x-t plane; drawn once, after the animation has finished.
if showXT, plotXTDiagram(xMin,xMax,x0,tEnd,leftState,rightState,gamma,exactInfo); end

%% ========================================================================
%                              局部函数
% ========================================================================


function Z = field2D(q,Ny)
    % Build the pseudo-2D display field from a 1D profile q(x).
    % Every one of the Ny rows of Z is a copy of q laid out as a row vector,
    % so Z is Ny-by-numel(q) and has no variation along y.
    qRow = reshape(q,1,[]);
    Z = qRow(ones(Ny,1),:);   % replicate by repeated row indexing
end

function items = legendItems(showInitialReference,showExactLine)
    % Legend labels for a line panel, in plotting order
    % (initial, numerical, exact). 'numerical' is always present; the other
    % two appear only when their flag is set.
    allLabels = {'initial','numerical','exact'};
    keep = [false, true, false];
    if showInitialReference
        keep(1) = true;
    end
    if showExactLine
        keep(3) = true;
    end
    items = allLabels(keep);
end

function climCompat(ax,lim)
    % Set the colour limits of axes AX to LIM = [cmin cmax].
    %
    % Setting the CLim property directly works on every MATLAB release (and
    % puts CLimMode into 'manual'). It avoids the previous approach: that
    % called the newer clim() and fell back to the deprecated caxis() in a
    % catch-all try/catch, which also swallowed genuine errors such as an
    % invalid LIM.
    set(ax,'CLim',lim);
end

function [leftState,rightState,tEnd] = getShockTubeCase(caseName)
    % Look up the initial states and a suggested final time of a named
    % shock-tube test case.
    %
    %   caseName   : 'Sod' | 'Lax' | 'StrongShock' | 'DoubleRarefaction'
    %                (case-insensitive)
    %   leftState  : [rho, u, p] on the left of the initial discontinuity
    %   rightState : [rho, u, p] on the right of the initial discontinuity
    %   tEnd       : suggested final time for this case
    %
    % Raises an error with identifier shockTube:unknownCase (naming the
    % offending value) for any other name.
    switch lower(caseName)
        case 'sod'
            leftState  = [1.0,   0.0, 1.0];
            rightState = [0.125, 0.0, 0.1];
            tEnd = 0.20;
        case 'lax'
            leftState  = [0.445, 0.698, 3.528];
            rightState = [0.500, 0.000, 0.571];
            tEnd = 0.16;
        case 'strongshock'
            leftState  = [1.0, 0.0, 1000.0];
            rightState = [1.0, 0.0, 0.01];
            tEnd = 0.012;
        case 'doublerarefaction'
            leftState  = [1.0, -2.0, 0.4];
            rightState = [1.0,  2.0, 0.4];
            tEnd = 0.15;
        otherwise
            error('shockTube:unknownCase', ...
                ['Unknown caseName: %s (expected ''Sod'', ''Lax'', ' ...
                 '''StrongShock'' or ''DoubleRarefaction'').'], char(caseName));
    end
end

function U = prim2cons(rho,u,p,gamma)
    % Convert primitive variables (column vectors rho, u, p) to the
    % conservative matrix U = [rho, rho*u, rho*E], where the specific total
    % energy is E = p/((gamma-1)*rho) + u^2/2.
    specificTotalEnergy = p./((gamma-1)*rho) + 0.5*u.^2;
    momentum = rho.*u;
    energyDensity = rho.*specificTotalEnergy;
    U = horzcat(rho, momentum, energyDensity);
end

function [rho,u,p] = cons2prim(U,gamma)
    % Convert conservative variables U = [rho, rho*u, rho*E] (one row per
    % cell) to primitive column vectors rho, u, p.
    %
    % Density and pressure are floored at 1e-12 so that the teaching demo
    % survives the negative values MacCormack can produce near strong
    % shocks. The density floor is applied BEFORE dividing by rho, so u, E
    % and p are computed from the same (floored) density that is returned.
    % Previously a zero density produced NaN velocities despite the floor.
    rho = max(U(:,1),1e-12);
    u   = U(:,2)./rho;
    E   = U(:,3)./rho;                  % specific total energy
    p   = (gamma-1)*rho.*(E-0.5*u.^2);
    p   = max(p,1e-12);
end

function F = fluxEuler(U,gamma)
    % Physical flux of the 1D Euler equations, evaluated row by row.
    %   U : conservative variables [rho, rho*u, rho*E], one row per cell
    %   F : [rho*u, rho*u^2 + p, u*(rho*E + p)] with the same number of rows
    % Primitive variables come from cons2prim (which floors rho and p).
    [dens,vel,pres] = cons2prim(U,gamma);
    specE = U(:,3)./dens;                    % specific total energy E
    massFlux     = dens.*vel;
    momentumFlux = dens.*vel.^2 + pres;
    energyFlux   = vel.*(dens.*specE + pres);
    F = [massFlux, momentumFlux, energyFlux];
end

function Unew = stepRusanov(U,gamma,dx,dt)
    % One forward-Euler finite-volume step with the Rusanov (local
    % Lax-Friedrichs) interface flux
    %
    %   F_{i+1/2}    = 0.5*(F_i + F_{i+1}) - 0.5*smax_{i+1/2}*(U_{i+1} - U_i)
    %   smax_{i+1/2} = max(|u_i| + a_i, |u_{i+1}| + a_{i+1})
    %
    %   U      : N-by-3 conservative variables [rho, rho*u, rho*E]
    %   gamma  : ratio of specific heats
    %   dx, dt : grid spacing and time step
    %   Unew   : updated state; rows 1 and N are returned unchanged from U
    %            (the caller imposes the boundary values)
    %
    % Vectorised over interfaces. It performs the same floating-point
    % operations as a per-interface loop, but avoids N-1 scalar cons2prim
    % calls per time step.
    N = size(U,1);
    F = fluxEuler(U,gamma);

    [rho,u,p] = cons2prim(U,gamma);
    cellSpeed = abs(u) + sqrt(gamma*p./rho);          % |u| + a in every cell
    smax = max(cellSpeed(1:N-1),cellSpeed(2:N));      % one value per interface

    % Row i of numFlux is the interface between cells i and i+1.
    numFlux = 0.5*(F(1:N-1,:)+F(2:N,:)) - 0.5*smax.*(U(2:N,:)-U(1:N-1,:));

    Unew = U;
    Unew(2:N-1,:) = U(2:N-1,:) - dt/dx*(numFlux(2:N-1,:)-numFlux(1:N-2,:));
end

function Unew = stepMacCormack(U,gamma,dx,dt,epsAV)
    % One step of the MacCormack predictor-corrector scheme plus a small
    % second-difference artificial viscosity.
    %
    %   U      : N-by-3 conservative variables [rho, rho*u, rho*E]
    %   gamma  : ratio of specific heats
    %   dx, dt : grid spacing and time step
    %   epsAV  : (optional, default 0.02) artificial-viscosity coefficient;
    %            pass 0 to see the scheme's own oscillations near shocks
    %   Unew   : updated state; rows 1 and N are returned unchanged from U
    %            (the caller imposes the boundary values)
    if nargin < 5 || isempty(epsAV)
        epsAV = 0.02;   % default keeps the classroom demo from blowing up
    end

    N = size(U,1);
    interior = 2:N-1;

    % Predictor: forward difference of the flux.
    F  = fluxEuler(U,gamma);
    Up = U;
    Up(interior,:) = U(interior,:) - dt/dx*(F(3:N,:)-F(interior,:));
    % Transmissive boundaries for the predicted state.
    Up(1,:) = Up(2,:);
    Up(N,:) = Up(N-1,:);

    % Corrector: backward difference of the predicted flux, averaged with U.
    Fp   = fluxEuler(Up,gamma);
    Unew = U;
    Unew(interior,:) = 0.5*(U(interior,:) + Up(interior,:) - dt/dx*(Fp(interior,:)-Fp(1:N-2,:)));

    % Artificial viscosity built from the old state U (second difference).
    Unew(interior,:) = Unew(interior,:) + epsAV*(U(3:N,:)-2*U(interior,:)+U(1:N-2,:));
end

function info = exactRiemannInfo(leftState,rightState,gamma)
    % Exact Riemann solver for the star region of the 1D Euler equations.
    %
    % Solves f_L(p) + f_R(p) + (uR - uL) = 0 for the star pressure by Newton
    % iteration (pressureFunction supplies f_K and its derivative), then
    % uStar = 0.5*(uL + uR + f_R(pStar) - f_L(pStar)).
    %
    %   leftState, rightState : [rho, u, p] on the two sides of the jump
    %   gamma                 : ratio of specific heats
    %   info : struct with fields
    %       pStar, uStar        star-region pressure and velocity
    %       aL, aR              sound speeds of the initial states
    %       leftWave, rightWave 'shock' if pStar exceeds that side's initial
    %                           pressure, otherwise 'rarefaction'
    %
    % Errors (shockTube:vacuum) when 2*(aL+aR)/(gamma-1) <= uR - uL: the data
    % then generate a vacuum, f_L + f_R + du has no positive root, and this
    % solver has no vacuum branch. Warns (shockTube:newtonNoConvergence)
    % if the Newton iteration does not meet its tolerance.
    rhoL = leftState(1); uL = leftState(2); pL = leftState(3);
    rhoR = rightState(1); uR = rightState(2); pR = rightState(3);

    aL = sqrt(gamma*pL/rhoL);
    aR = sqrt(gamma*pR/rhoR);

    % Pressure positivity (no-vacuum) condition.
    duCrit = 2*(aL+aR)/(gamma-1);
    if duCrit <= uR-uL
        error('shockTube:vacuum', ...
            'Initial data generate a vacuum (uR - uL = %g >= %g); not supported by this exact solver.', ...
            uR-uL, duCrit);
    end

    % Starting guess: linearised primitive-variable estimate, kept positive.
    pPV  = 0.5*(pL+pR) - 0.125*(uR-uL)*(rhoL+rhoR)*(aL+aR);
    pOld = max(1e-8,pPV);

    maxIter   = 80;
    relTol    = 1e-10;
    converged = false;
    for k = 1:maxIter
        [fL,dfL] = pressureFunction(pOld,rhoL,pL,aL,gamma);
        [fR,dfR] = pressureFunction(pOld,rhoR,pR,aR,gamma);

        pNew = pOld - (fL+fR+uR-uL)/(dfL+dfR);
        pNew = max(pNew,1e-10);   % keep the iterate positive

        % Converged when the relative pressure change is below relTol.
        if abs(pNew-pOld)/(0.5*(pNew+pOld)) < relTol
            converged = true;
            break;
        end
        pOld = pNew;
    end
    if ~converged
        warning('shockTube:newtonNoConvergence', ...
            'Star-pressure Newton iteration did not converge in %d iterations; using p = %g.', ...
            maxIter, pNew);
    end

    pStar = pNew;
    [fL,~] = pressureFunction(pStar,rhoL,pL,aL,gamma);
    [fR,~] = pressureFunction(pStar,rhoR,pR,aR,gamma);
    uStar = 0.5*(uL+uR+fR-fL);

    info.pStar = pStar;
    info.uStar = uStar;
    info.aL = aL;
    info.aR = aR;

    % A side whose initial pressure is exceeded by pStar is crossed by a shock.
    if pStar > pL
        info.leftWave = 'shock';
    else
        info.leftWave = 'rarefaction';
    end

    if pStar > pR
        info.rightWave = 'shock';
    else
        info.rightWave = 'rarefaction';
    end
end

function [f,df] = pressureFunction(p,rhoK,pK,aK,gamma)
    % Pressure function f_K(p) of the exact Riemann solver and its
    % derivative df_K/dp for one side K of the discontinuity.
    %   p              : trial star pressure (scalar; ^ is not element-wise)
    %   rhoK, pK, aK   : density, pressure, sound speed of the state on side K
    %   gamma          : ratio of specific heats
    % p > pK selects the shock branch, otherwise the rarefaction branch.
    if p > pK
        % Shock branch: f = (p - pK)*sqrt(A/(p + B)).
        A = 2/((gamma+1)*rhoK);
        B = (gamma-1)/(gamma+1)*pK;
        root = sqrt(A/(p+B));
        jump = p-pK;
        f  = jump*root;
        df = root*(1-0.5*jump/(p+B));
    else
        % Rarefaction branch: f = 2a/(gamma-1)*((p/pK)^((gamma-1)/(2gamma)) - 1).
        ratio = p/pK;
        expo  = (gamma-1)/(2*gamma);
        f  = 2*aK/(gamma-1)*(ratio^expo-1);
        df = 1/(rhoK*aK)*ratio^(-(gamma+1)/(2*gamma));
    end
end

function [rho,u,p,M] = exactSolution(x,t,x0,leftState,rightState,gamma)
    % Sample the exact Riemann-problem solution at positions x and time t.
    %
    % The solution depends only on xi = (x - x0)/t. Each point is first placed
    % left or right of the contact (speed uStar), then compared with the wave
    % on that side (shock or rarefaction fan, as decided from pStar vs pL/pR).
    %
    %   x                     : sample positions (outputs have size(x))
    %   t                     : time; it is clamped to at least 1e-14 to avoid
    %                           division by zero, so t = 0 reproduces the
    %                           initial step except possibly at x == x0
    %   x0                    : initial discontinuity position
    %   leftState, rightState : [rho, u, p] on the left/right of x0
    %   gamma                 : ratio of specific heats
    %   rho, u, p             : density, velocity, pressure at x
    %   M                     : local Mach number u/sqrt(gamma*p/rho)
    %
    % The star state is re-solved via exactRiemannInfo on every call.
    rhoL = leftState(1); uL = leftState(2); pL = leftState(3);
    rhoR = rightState(1); uR = rightState(2); pR = rightState(3);

    info = exactRiemannInfo(leftState,rightState,gamma);
    pStar = info.pStar;
    uStar = info.uStar;
    aL = info.aL;
    aR = info.aR;

    % Similarity variable; t is clamped away from zero.
    xi = (x-x0)./max(t,1e-14);

    rho = zeros(size(x));
    u   = zeros(size(x));
    p   = zeros(size(x));

    for i = 1:length(x)
        s = xi(i);

        if s <= uStar
            % Left of the contact: left-moving wave family
            if pStar > pL
                % Left shock with speed SL; behind it the post-shock star density
                SL = uL - aL*sqrt((gamma+1)/(2*gamma)*pStar/pL + (gamma-1)/(2*gamma));
                if s <= SL
                    rho(i)=rhoL; u(i)=uL; p(i)=pL;
                else
                    rhoStarL = rhoL*((pStar/pL + (gamma-1)/(gamma+1)) / ...
                        ((gamma-1)/(gamma+1)*pStar/pL + 1));
                    rho(i)=rhoStarL; u(i)=uStar; p(i)=pStar;
                end
            else
                % Left rarefaction fan between head speed SHL and tail speed STL
                aStarL = aL*(pStar/pL)^((gamma-1)/(2*gamma));
                SHL = uL - aL;
                STL = uStar - aStarL;

                if s <= SHL
                    rho(i)=rhoL; u(i)=uL; p(i)=pL;
                elseif s > STL
                    % Between fan tail and contact: isentropic star state
                    rhoStarL = rhoL*(pStar/pL)^(1/gamma);
                    rho(i)=rhoStarL; u(i)=uStar; p(i)=pStar;
                else
                    % Inside the fan: self-similar rarefaction state at speed s
                    uFan = 2/(gamma+1)*(aL + 0.5*(gamma-1)*uL + s);
                    aFan = 2/(gamma+1)*(aL + 0.5*(gamma-1)*(uL-s));
                    rhoFan = rhoL*(aFan/aL)^(2/(gamma-1));
                    pFan = pL*(aFan/aL)^(2*gamma/(gamma-1));
                    rho(i)=rhoFan; u(i)=uFan; p(i)=pFan;
                end
            end
        else
            % Right of the contact: right-moving wave family
            if pStar > pR
                % Right shock with speed SR; behind it the post-shock star density
                SR = uR + aR*sqrt((gamma+1)/(2*gamma)*pStar/pR + (gamma-1)/(2*gamma));
                if s >= SR
                    rho(i)=rhoR; u(i)=uR; p(i)=pR;
                else
                    rhoStarR = rhoR*((pStar/pR + (gamma-1)/(gamma+1)) / ...
                        ((gamma-1)/(gamma+1)*pStar/pR + 1));
                    rho(i)=rhoStarR; u(i)=uStar; p(i)=pStar;
                end
            else
                % Right rarefaction fan between tail speed STR and head speed SHR
                aStarR = aR*(pStar/pR)^((gamma-1)/(2*gamma));
                SHR = uR + aR;
                STR = uStar + aStarR;

                if s >= SHR
                    rho(i)=rhoR; u(i)=uR; p(i)=pR;
                elseif s <= STR
                    % Between contact and fan tail: isentropic star state
                    rhoStarR = rhoR*(pStar/pR)^(1/gamma);
                    rho(i)=rhoStarR; u(i)=uStar; p(i)=pStar;
                else
                    % Inside the fan: self-similar rarefaction state at speed s
                    uFan = 2/(gamma+1)*(-aR + 0.5*(gamma-1)*uR + s);
                    aFan = 2/(gamma+1)*(aR - 0.5*(gamma-1)*(uR-s));
                    rhoFan = rhoR*(aFan/aR)^(2/(gamma-1));
                    pFan = pR*(aFan/aR)^(2*gamma/(gamma-1));
                    rho(i)=rhoFan; u(i)=uFan; p(i)=pFan;
                end
            end
        end
    end

    % Local Mach number
    M = u./sqrt(gamma*p./rho);
end

function lim = expandLimits(y)
    % Return axis limits [lo hi] that enclose the values of y with a margin
    % of 10% of their range. When y is (nearly) constant, the margin is
    % 10% of max(1,|max(y)|) instead, so the interval never has zero width.
    lo = min(y);
    hi = max(y);
    span = hi - lo;
    if abs(span) < 1e-12
        margin = 0.1*max(1,abs(hi));
    else
        margin = 0.1*span;
    end
    lim = [lo-margin, hi+margin];
end

function plotXTDiagram(xMin,xMax,x0,tEnd,leftState,rightState,gamma,info)
    % Open a new figure showing the exact wave pattern in the x-t plane:
    % the contact, plus either a shock or a rarefaction fan (head, tail and
    % interior characteristics) on each side.
    %
    %   xMin, xMax            : x-range of the plot
    %   x0                    : position of the initial discontinuity
    %   tEnd                  : top of the t-axis
    %   leftState, rightState : [rho, u, p] on either side of x0
    %   gamma                 : ratio of specific heats
    %   info                  : struct from exactRiemannInfo
    %                           (pStar, uStar, leftWave, rightWave)
    [rhoL,uL,pL] = deal(leftState(1),leftState(2),leftState(3));
    [rhoR,uR,pR] = deal(rightState(1),rightState(2),rightState(3));
    aL = sqrt(gamma*pL/rhoL);
    aR = sqrt(gamma*pR/rhoR);
    pS = info.pStar;
    uS = info.uStar;

    figure('Color','w','Position',[120 120 900 620]);
    hold on; box on; grid on;
    xlabel('x'); ylabel('t');
    title('x-t diagram of wave structure and characteristic families');

    tt = linspace(0,tEnd,200);
    % Every wave is a straight ray x = x0 + speed*t starting at (x0, 0).
    ray = @(speed,varargin) plot(x0 + speed*tt,tt,varargin{:});

    % Contact discontinuity travels with the star velocity.
    ray(uS,'LineWidth',2.2,'DisplayName','contact: dx/dt=u_*');

    % ----- left wave family -----
    if strcmp(info.leftWave,'shock')
        shockSpeed = uL - aL*sqrt((gamma+1)/(2*gamma)*pS/pL + (gamma-1)/(2*gamma));
        ray(shockSpeed,'LineWidth',2.2,'DisplayName','left shock');
    else
        headSpeed = uL - aL;
        tailSpeed = uS - aL*(pS/pL)^((gamma-1)/(2*gamma));
        ray(headSpeed,'LineWidth',2.0,'DisplayName','left rarefaction head');
        ray(tailSpeed,'LineWidth',2.0,'DisplayName','left rarefaction tail');

        % Interior characteristics of the fan (kept out of the legend).
        fanSpeeds = linspace(headSpeed,tailSpeed,9);
        for k = 1:numel(fanSpeeds)
            ray(fanSpeeds(k),':','LineWidth',1.0,'HandleVisibility','off');
        end
    end

    % ----- right wave family -----
    if strcmp(info.rightWave,'shock')
        shockSpeed = uR + aR*sqrt((gamma+1)/(2*gamma)*pS/pR + (gamma-1)/(2*gamma));
        ray(shockSpeed,'LineWidth',2.2,'DisplayName','right shock');
    else
        headSpeed = uR + aR;
        tailSpeed = uS + aR*(pS/pR)^((gamma-1)/(2*gamma));
        ray(tailSpeed,'LineWidth',2.0,'DisplayName','right rarefaction tail');
        ray(headSpeed,'LineWidth',2.0,'DisplayName','right rarefaction head');

        fanSpeeds = linspace(tailSpeed,headSpeed,9);
        for k = 1:numel(fanSpeeds)
            ray(fanSpeeds(k),':','LineWidth',1.0,'HandleVisibility','off');
        end
    end

    xlim([xMin,xMax]);
    ylim([0,tEnd]);
    legend('Location','bestoutside');

    text(x0,0,' initial discontinuity','VerticalAlignment','bottom');
    text(xMin+0.03*(xMax-xMin),0.92*tEnd, ...
        sprintf('\\lambda_1=u-a,   \\lambda_2=u,   \\lambda_3=u+a'),'FontSize',12);
end
