%% riemann2d_euler_contour_demo.m
% 二维 Euler 方程四象限 Riemann 问题动态演示
% -------------------------------------------------------------------------
% 教学目标：
% 1. 展示二维 Riemann 问题中激波、接触间断、滑移线和膨胀波的相互作用；
% 2. 用动态伪彩色图（imagesc）观察二维波系如何从四象限初始间断中发展出来，并在结束时用 contourf 等值线图展示最终流场；
% 3. 通过切换不同初始条件，比较不同波系结构。
%
% 控制方程：二维无粘可压缩 Euler 方程
%
%     U_t + F(U)_x + G(U)_y = 0
%
% 数值方法：
% - 有限体积思想；
% - x/y 方向均采用 Rusanov / local Lax-Friedrichs 通量；
% - 一阶格式，鲁棒稳定，适合教学演示；
% - 网格可适当加密以显示更细结构。
%
% 说明：
% - 一阶 Rusanov 格式有数值耗散，结构会被一定程度抹宽；
% - 若要展示更锐利结构，可进一步采用 MUSCL/WENO 高阶重构，或改用 HLLC 等近似 Riemann 求解器通量；
% - 本脚本重在课堂直观演示二维 Riemann 问题的波系演化。
%
% 推荐运行方式：
% 先用 caseName='Config12' 或 'Config3'；
% 再尝试 'Implosion'、'Explosion'、'KHLike'。
%
% 作者：ChatGPT for CFD teaching demo

clear; clc; close all;

%% ====================== User-adjustable parameters ======================
caseName = 'Config12';
% Available cases:
% 'Config3'
% 'Config4'
% 'Config6'
% 'Config12'
% 'Implosion'
% 'Explosion'
% 'KHLike'

gamma = 1.4;

% Grid size: 300~500 recommended for demos; reduce to ~200 on slower machines.
Nx = 420;
Ny = 420;

xMin = 0.0; xMax = 1.0;
yMin = 0.0; yMax = 1.0;

% Location of the quadrant interfaces x = x0 and y = y0.
x0 = 0.5;
y0 = 0.5;

% CFL number applied to min(dx,dy)/max(|u|+|v|+a) when choosing dt.
CFL = 0.42;

% Final time. Leave empty ([]) to use the end time suggested by the selected
% case (e.g. 0.30 for Config6, 0.45 for KHLike); set a number to force the
% same end time for every case.
tEnd = [];

plotEvery = 5;        % redraw (and record a GIF frame, if enabled) every N steps
contourLevels = 80;   % number of filled contour levels in the final figure

% Single-panel animation variable: 'rho' | 'p' | 'speed' | 'mach' | 'schlieren'
% (only used when showFourPanels = false)
plotMode = 'rho';

% Show density, pressure, speed and schlieren simultaneously
showFourPanels = true;

% Overlay the initial quadrant interface lines
showInitialLines = true;

% Save the animation as a GIF
saveGif = false;
gifName = ['riemann2d_',caseName,'_',plotMode,'.gif'];

%% ====================== Grid ======================
% Uniform node-based grid; field arrays are Ny-by-Nx (rows follow y,
% columns follow x), matching meshgrid(x,y).
x = linspace(xMin,xMax,Nx);
y = linspace(yMin,yMax,Ny);
dx = diff(x(1:2));
dy = diff(y(1:2));

[X,Y] = meshgrid(x,y);

%% ====================== Initial condition ======================
% Quadrant layout:
%
%      Q2 | Q1
%     ----+----
%      Q3 | Q4
%
% Q1: x>=x0, y>=y0  (upper right)
% Q2: x< x0, y>=y0  (upper left)
% Q3: x< x0, y< y0  (lower left)
% Q4: x>=x0, y< y0  (lower right)
% Nodes lying exactly on x = x0 or y = y0 belong to the right/upper side.

[Q1,Q2,Q3,Q4,tEndDefault,caseDescription] = getRiemann2DCase(caseName);

% An empty tEnd means "use the end time suggested by the selected case".
if isempty(tEnd), tEnd = tEndDefault; end

onRight = (X >= x0);
onTop   = (Y >= y0);
maskQ1 =  onRight &  onTop;
maskQ2 = ~onRight &  onTop;
maskQ3 = ~onRight & ~onTop;
maskQ4 =  onRight & ~onTop;

% Start from all-zero primitive fields and paint each quadrant in turn.
[rho,u,v,p] = deal(zeros(Ny,Nx));
quadrantMasks  = {maskQ1, maskQ2, maskQ3, maskQ4};
quadrantStates = {Q1, Q2, Q3, Q4};
for iq = 1:4
    [rho,u,v,p] = assignQuadrant(rho,u,v,p,quadrantMasks{iq},quadrantStates{iq});
end

U = prim2cons2D(rho,u,v,p,gamma);

% Run summary banner
fprintf('\n============================================================\n');
fprintf('2D Riemann problem demo\n');
fprintf('Case: %s\n',caseName);
fprintf('%s\n',caseDescription);
fprintf('Grid: %d x %d\n',Nx,Ny);
fprintf('tEnd = %.4f, CFL = %.3f\n',tEnd,CFL);
fprintf('============================================================\n\n');

%% ====================== Figure initialization ======================
fig = figure('Color','w','Position',[60 60 1250 930]);

if showFourPanels
    tl = tiledlayout(fig,2,2,'TileSpacing','compact','Padding','compact');

    % Derived fields shown in the lower two tiles.
    speed = sqrt(u.^2+v.^2);
    sch = syntheticSchlieren(rho,dx,dy);

    % Request tiles from this figure's layout explicitly, so the panels are
    % placed in tl even if another figure has become current.
    axRho = nexttile(tl);
    hRho = imagesc(axRho,x,y,rho);
    title(axRho,'Density \rho');

    axP = nexttile(tl);
    hP = imagesc(axP,x,y,p);
    title(axP,'Pressure p');

    axSpeed = nexttile(tl);
    hSpeed = imagesc(axSpeed,x,y,speed);
    title(axSpeed,'Speed |\bfu|');

    axSch = nexttile(tl);
    hSch = imagesc(axSch,x,y,sch);
    % syntheticSchlieren returns log(1+|grad rho|); label it accordingly
    % (same text as fieldTitle('schlieren') in single-panel mode).
    title(axSch,'Synthetic schlieren log(1+|\nabla \rho|)');

    axs = [axRho, axP, axSpeed, axSch];
else
    ax = axes(fig);
    field = selectField(plotMode,rho,u,v,p,gamma,dx,dy);
    hField = imagesc(ax,x,y,field);
    title(ax,fieldTitle(plotMode));
    axs = ax;
end

% Shared axis setup for every panel.
for axNow = axs
    set(axNow,'YDir','normal');   % imagesc draws row 1 at the top; make y increase upward
    axis(axNow,'equal','tight');
    colorbar(axNow);
    hold(axNow,'on');             % keep the image when overlaying the interface lines
    xlabel(axNow,'x');
    ylabel(axNow,'y');
    if showInitialLines
        plot(axNow,[x0 x0],[yMin yMax],'k--','LineWidth',0.9);
        plot(axNow,[xMin xMax],[y0 y0],'k--','LineWidth',0.9);
    end
end

sgtitle(fig,sprintf('2D Euler Riemann problem: %s, t = %.4f',caseName,0.0), ...
    'FontWeight','bold');

drawnow;

%% ====================== Time marching ======================
t = 0.0;
step = 0;
% Whether the GIF file has been created yet. The first recorded frame must
% create/overwrite the file (and set LoopCount); later frames append. Keying
% this on "step == plotEvery" failed when the run finished before the first
% regular plot step (short tEnd or large plotEvery).
gifStarted = false;

while t < tEnd
    [rho,u,v,p] = cons2prim2D(U,gamma);

    % Time step from a conservative bound on the signal speed, |u|+|v|+a.
    a = sqrt(gamma*p./rho);
    maxSpeed = max(abs(u(:))+abs(v(:))+a(:));
    dt = CFL*min(dx,dy)/maxSpeed;

    % A NaN/Inf/non-positive dt means the solution has diverged; without this
    % check t would become NaN and the loop would exit silently.
    if ~isfinite(dt) || dt <= 0
        error('Invalid time step dt = %g at t = %.6f (step %d); the solution may have diverged.', ...
            dt,t,step);
    end

    % Shorten the final step so the run stops at tEnd.
    isLastStep = (t+dt >= tEnd);
    if isLastStep
        dt = tEnd-t;
    end

    U = stepRusanov2D(U,gamma,dx,dy,dt);

    % transmissive (zero-gradient) boundary
    U = applyTransmissiveBC(U);

    % On the last step assign tEnd directly: t+(tEnd-t) need not reproduce
    % tEnd exactly in floating point, which could trigger an extra tiny step.
    if isLastStep
        t = tEnd;
    else
        t = t+dt;
    end
    step = step+1;

    if mod(step,plotEvery)==0 || isLastStep
        [rho,u,v,p] = cons2prim2D(U,gamma);

        if showFourPanels
            speed = sqrt(u.^2+v.^2);
            sch = syntheticSchlieren(rho,dx,dy);

            set(hRho,'CData',rho);
            set(hP,'CData',p);
            set(hSpeed,'CData',speed);
            set(hSch,'CData',sch);

            % Percentile-based colour limits so extreme values do not dominate
            set(axRho,'CLim',robustCLim(rho));
            set(axP,'CLim',robustCLim(p));
            set(axSpeed,'CLim',robustCLim(speed));
            set(axSch,'CLim',robustCLim(sch));
        else
            field = selectField(plotMode,rho,u,v,p,gamma,dx,dy);
            set(hField,'CData',field);
            set(ax,'CLim',robustCLim(field));
        end

        sgtitle(fig,sprintf('2D Euler Riemann problem: %s, t = %.4f, step = %d', ...
            caseName,t,step),'FontWeight','bold');

        drawnow;

        if saveGif
            frame = getframe(fig);
            [im,map] = rgb2ind(frame2im(frame),256);
            if ~gifStarted
                imwrite(im,map,gifName,'gif','LoopCount',inf,'DelayTime',0.06);
                gifStarted = true;
            else
                imwrite(im,map,gifName,'gif','WriteMode','append','DelayTime',0.06);
            end
        end
    end
end

fprintf('Done. Final time t = %.6f, steps = %d\n',t,step);

%% ====================== Final contour plots ======================
figure('Color','w','Position',[120 80 1200 900]);
tl2 = tiledlayout(2,2,'TileSpacing','compact','Padding','compact');

% Final primitive and derived fields
[rho,u,v,p] = cons2prim2D(U,gamma);
speed = sqrt(u.^2+v.^2);
mach = speed./sqrt(gamma*p./rho);
sch = syntheticSchlieren(rho,dx,dy);

% One filled-contour tile per field, in the order density, pressure, Mach,
% schlieren.
finalFields = {rho, p, mach, sch};
finalTitles = {'Final density \rho', 'Final pressure p', ...
               'Final Mach number', 'Final synthetic schlieren'};
for iPanel = 1:numel(finalFields)
    nexttile;
    contourf(X,Y,finalFields{iPanel},contourLevels,'LineColor','none');
    axis equal tight;
    colorbar;
    title(finalTitles{iPanel});
    xlabel('x');
    ylabel('y');
end

sgtitle(sprintf('Final fields: %s, t=%.4f',caseName,t),'FontWeight','bold');

%% ========================================================================
%                              局部函数
% ========================================================================

function [Q1,Q2,Q3,Q4,tEnd,desc] = getRiemann2DCase(caseName)
% Return the four quadrant states, a suggested end time and a one-line
% description for a named 2-D Riemann problem (caseName is case-insensitive).
%
% State layout: Q = [rho, u, v, p]
% Quadrants:    Q1 upper right, Q2 upper left, Q3 lower left, Q4 lower right.
%
% Configuration numbering varies slightly between references. The states for
% Config3/4/6/12 follow Lax & Liu (1998) as tabulated by Kurganov & Tadmor
% (2002); the suggested end times may be shorter than in those papers.
% Implosion / Explosion / KHLike are ad-hoc teaching setups.
% Raises an error for an unknown caseName.

    switch lower(caseName)
        case 'config3'
            % Classic four-shock interaction with a rich wave pattern
            Q1 = [1.5,   0.0,    0.0,    1.5];
            Q2 = [0.5323,1.206,  0.0,    0.3];
            Q3 = [0.138, 1.206,  1.206,  0.029];
            Q4 = [0.5323,0.0,    1.206,  0.3];
            tEnd = 0.25;
            desc = 'Four shocks/contacts interaction; rich wave pattern.';

        case 'config4'
            Q1 = [1.1,   0.0,    0.0,    1.1];
            Q2 = [0.5065,0.8939, 0.0,    0.35];
            Q3 = [1.1,   0.8939, 0.8939, 1.1];
            Q4 = [0.5065,0.0,    0.8939, 0.35];
            tEnd = 0.25;
            desc = 'Mixed shock/contact pattern; suitable for comparing density and schlieren.';

        case 'config6'
            Q1 = [1.0,   0.75,  -0.5,   1.0];
            Q2 = [2.0,   0.75,   0.5,   1.0];
            Q3 = [1.0,  -0.75,   0.5,   1.0];
            Q4 = [3.0,  -0.75,  -0.5,   1.0];
            tEnd = 0.30;
            desc = 'Strong shear and contact interaction; displays fine vortical-like structures.';

        case 'config12'
            % Lax-Liu Configuration 12. Q1 is the low-pressure quadrant; Q2
            % (u > 0) and Q4 (v > 0) move into it, driving shocks across
            % Q1|Q2 and Q1|Q4. Q2|Q3 and Q3|Q4 share pressure and normal
            % velocity, so they are slip lines. The setup is symmetric about
            % the diagonal x = y (Q4 is Q2 with u and v swapped).
            Q1 = [0.5313, 0.0,     0.0,     0.4];
            Q2 = [1.0,    0.7276,  0.0,     1.0];
            Q3 = [0.8,    0.0,     0.0,     1.0];
            Q4 = [1.0,    0.0,     0.7276,  1.0];
            tEnd = 0.25;
            desc = 'Lax-Liu Configuration 12: two shocks (Q1|Q2, Q1|Q4) and two slip lines (Q2|Q3, Q3|Q4).';

        case 'implosion'
            % Low-density, low-pressure corner quadrant surrounded by
            % high-pressure gas; compression is driven into Q3
            Q1 = [1.0, 0.0, 0.0, 1.0];
            Q2 = [1.0, 0.0, 0.0, 1.0];
            Q3 = [0.125,0.0,0.0,0.14];
            Q4 = [1.0, 0.0, 0.0, 1.0];
            tEnd = 0.20;
            desc = 'Implosion-like case: compression waves focus toward low-pressure quadrant.';

        case 'explosion'
            % High-pressure quadrant Q3 expanding into low-pressure gas
            Q1 = [0.125,0.0,0.0,0.1];
            Q2 = [0.125,0.0,0.0,0.1];
            Q3 = [1.0,  0.0,0.0,1.0];
            Q4 = [0.125,0.0,0.0,0.1];
            tEnd = 0.20;
            desc = 'Explosion-like case: high pressure quadrant expands outward.';

        case 'khlike'
            % Shear-layer setup: opposite x-velocities and a density jump on
            % either side of x = x0, to show slip-line/contact roll-up tendency
            Q1 = [1.0,  0.5,  0.0, 1.0];
            Q2 = [2.0, -0.5,  0.0, 1.0];
            Q3 = [1.0,  0.5,  0.0, 1.0];
            Q4 = [2.0, -0.5,  0.0, 1.0];
            tEnd = 0.45;
            desc = 'Kelvin-Helmholtz-like shear/contact demonstration; needs fine grid.';

        otherwise
            error('Unknown caseName: %s',caseName);
    end
end

function [rho,u,v,p] = assignQuadrant(rho,u,v,p,mask,Q)
% Overwrite the cells selected by MASK in each primitive field with the
% constant state Q = [rho, u, v, p]; all other cells are left unchanged.
    prim = {rho,u,v,p};
    for n = 1:4
        prim{n}(mask) = Q(n);
    end
    [rho,u,v,p] = prim{:};
end

function U = prim2cons2D(rho,u,v,p,gamma)
% Convert primitive fields (rho,u,v,p) to the conservative array
%   U(:,:,1) = rho,  U(:,:,2) = rho*u,  U(:,:,3) = rho*v,  U(:,:,4) = rho*E,
% where the specific total energy is E = p/((gamma-1)*rho) + (u^2+v^2)/2.
% U is size(rho,1)-by-size(rho,2)-by-4.
    specificE = p./((gamma-1)*rho) + 0.5*(u.^2+v.^2);
    pages = {rho, rho.*u, rho.*v, rho.*specificE};
    U = zeros(size(rho,1), size(rho,2), numel(pages));
    for k = 1:numel(pages)
        U(:,:,k) = pages{k};
    end
end

function [rho,u,v,p] = cons2prim2D(U,gamma)
% Recover primitive variables (rho,u,v,p) from the conservative pages of U:
% [rho, rho*u, rho*v, rho*E].
% Density and pressure are floored at 1e-12 so later divisions and square
% roots (e.g. the sound speed) stay finite. The floor applies only to the
% returned values; U itself is not modified.
    floorVal = 1e-12;

    rho = max(U(:,:,1), floorVal);
    u = U(:,:,2)./rho;
    v = U(:,:,3)./rho;

    kineticPerMass = 0.5*(u.^2+v.^2);
    p = max((gamma-1)*rho.*(U(:,:,4)./rho - kineticPerMass), floorVal);
end

function F = fluxX(U,gamma)
% Physical x-direction Euler flux for the conservative array U (pages
% [rho, rho*u, rho*v, rho*E]):
%   F = [rho*u, rho*u^2 + p, rho*u*v, u*(rho*E + p)]
% Primitive variables come from cons2prim2D, so its density/pressure floors
% carry over into the flux.
    [rho,u,v,p] = cons2prim2D(U,gamma);
    rhoE = rho.*(U(:,:,4)./rho);   % rho*E rebuilt with the floored density

    F = zeros(size(U));
    fluxPages = {rho.*u, rho.*u.^2+p, rho.*u.*v, u.*(rhoE+p)};
    for k = 1:numel(fluxPages)
        F(:,:,k) = fluxPages{k};
    end
end

function G = fluxY(U,gamma)
% Physical y-direction Euler flux for the conservative array U (pages
% [rho, rho*u, rho*v, rho*E]):
%   G = [rho*v, rho*u*v, rho*v^2 + p, v*(rho*E + p)]
% Primitive variables (with floors) come from cons2prim2D.
    [rho,u,v,p] = cons2prim2D(U,gamma);
    totalEnergy = rho.*(U(:,:,4)./rho);

    G = zeros(size(U));
    G(:,:,1:4) = cat(3, rho.*v, rho.*u.*v, rho.*v.^2+p, v.*(totalEnergy+p));
end

function Unew = stepRusanov2D(U,gamma,dx,dy,dt)
% One explicit forward-Euler step of the first-order finite-volume scheme
% with Rusanov (local Lax-Friedrichs) interface fluxes in x and y
% (unsplit update).
%
% Inputs:
%   U      Ny-by-Nx-by-4 conservative state [rho, rho*u, rho*v, rho*E]
%   gamma  ratio of specific heats
%   dx,dy  spacings along columns (x) and rows (y)
%   dt     time step
% Output:
%   Unew   updated state. Only interior cells (2:Ny-1, 2:Nx-1) are advanced;
%          the outer ring is copied from U unchanged, so the caller must
%          apply a boundary condition afterwards.
%
% Primitive variables and sound speed are computed once for the whole grid.
% Both interface sweeps use whole-array operations (implicit expansion of
% the wave-speed matrix across the 4 flux pages), which gives the same
% values as converting each column/row slice inside a loop.
    [Ny,Nx,~] = size(U);

    [~,u,v,p] = cons2prim2D(U,gamma);
    rho = max(U(:,:,1),1e-12);   % same floored density cons2prim2D used for p
    a = sqrt(gamma*p./rho);

    Fx = fluxX(U,gamma);
    Gy = fluxY(U,gamma);

    % x-direction numerical flux at i+1/2, i = 1..Nx-1  -> Ny-by-(Nx-1)-by-4
    sx = max(abs(u(:,1:Nx-1))+a(:,1:Nx-1), abs(u(:,2:Nx))+a(:,2:Nx));
    FxNum = 0.5*(Fx(:,1:Nx-1,:)+Fx(:,2:Nx,:)) ...
          - 0.5*sx.*(U(:,2:Nx,:)-U(:,1:Nx-1,:));

    % y-direction numerical flux at j+1/2, j = 1..Ny-1  -> (Ny-1)-by-Nx-by-4
    sy = max(abs(v(1:Ny-1,:))+a(1:Ny-1,:), abs(v(2:Ny,:))+a(2:Ny,:));
    GyNum = 0.5*(Gy(1:Ny-1,:,:)+Gy(2:Ny,:,:)) ...
          - 0.5*sy.*(U(2:Ny,:,:)-U(1:Ny-1,:,:));

    % Conservative update of interior cells: flux differences across each cell.
    iIn = 2:Nx-1;
    jIn = 2:Ny-1;
    Unew = U;
    Unew(jIn,iIn,:) = U(jIn,iIn,:) ...
        - dt/dx*(FxNum(jIn,iIn,:)-FxNum(jIn,iIn-1,:)) ...
        - dt/dy*(GyNum(jIn,iIn,:)-GyNum(jIn-1,iIn,:));
end

function U = applyTransmissiveBC(U)
% Zero-gradient (transmissive) boundary: each outermost column/row is
% overwritten with its interior neighbour. Columns are handled before rows,
% so each corner cell ends up equal to its diagonal interior neighbour.
    nRow = size(U,1);
    nCol = size(U,2);

    % left / right edges
    U(:,1,:)    = U(:,2,:);
    U(:,nCol,:) = U(:,nCol-1,:);

    % bottom / top edges (rows)
    U(1,:,:)    = U(2,:,:);
    U(nRow,:,:) = U(nRow-1,:,:);
end

function field = selectField(plotMode,rho,u,v,p,gamma,dx,dy)
% Return the 2-D array to display for the requested plot mode.
%   plotMode: 'rho' | 'p' | 'speed' | 'mach' | 'schlieren' (case-insensitive)
%   rho,u,v,p: primitive fields; gamma: ratio of specific heats;
%   dx,dy: grid spacings (used only by the schlieren mode).
% Raises an error naming the offending mode when plotMode is not recognised.
    switch lower(plotMode)
        case 'rho'
            field = rho;
        case 'p'
            field = p;
        case 'speed'
            field = sqrt(u.^2+v.^2);
        case 'mach'
            % local Mach number |u|/a with sound speed a = sqrt(gamma*p/rho)
            field = sqrt(u.^2+v.^2)./sqrt(gamma*p./rho);
        case 'schlieren'
            field = syntheticSchlieren(rho,dx,dy);
        otherwise
            error('Unknown plotMode: %s',plotMode);
    end
end

function s = syntheticSchlieren(rho,dx,dy)
% Numerical schlieren indicator s = log(1 + |grad rho|).
%
% rho is laid out like meshgrid output: columns follow x (spacing dx) and
% rows follow y (spacing dy). gradient(F,hx,hy) returns the column-direction
% derivative first (using hx) and the row-direction derivative second
% (using hy), so the x spacing must be passed first. The previous call,
% gradient(rho,dy,dx), swapped the spacings and was wrong whenever dx ~= dy.
% The log compresses the large dynamic range between smooth regions and
% shocks.
    [drdx,drdy] = gradient(rho,dx,dy);
    s = log(1 + sqrt(drdx.^2 + drdy.^2));
end

function titleStr = fieldTitle(plotMode)
% Axis title (TeX markup) for a plot mode, matched case-insensitively.
% Unrecognised modes fall back to the mode name itself.
    modeKeys   = {'rho','p','speed','mach','schlieren'};
    modeTitles = {'Density \rho', 'Pressure p', 'Speed |\bfu|', ...
                  'Mach number', 'Synthetic schlieren log(1+|\nabla \rho|)'};
    hit = find(strcmp(modeKeys, lower(plotMode)), 1);
    if isempty(hit)
        titleStr = plotMode;
    else
        titleStr = modeTitles{hit};
    end
end

function clim = robustCLim(A)
% Colour limits spanning the 1st-99th percentile of the finite entries of A,
% so isolated extreme values do not wash out the colormap.
% Returns [0 1] when A has no finite entries. When the percentile range is
% (nearly) zero, it is padded by 0.1*max(1,|hi|) on each side so the limits
% stay strictly increasing.
    vals = A(:);
    vals = vals(isfinite(vals));
    if isempty(vals)
        clim = [0 1];
        return;
    end

    bounds = [prctileLocal(vals,1), prctileLocal(vals,99)];
    if abs(diff(bounds)) < 1e-12
        pad = 0.1*max(1,abs(bounds(2)));
        clim = bounds + [-pad, pad];
    else
        clim = bounds;
    end
end

function q = prctileLocal(a,p)
% Toolbox-free p-th percentile (p in 0..100) of the entries of a.
% Linearly interpolates between sorted samples at fractional position
% 1 + (n-1)*p/100. This convention may differ slightly from the Statistics
% Toolbox prctile. A single-element input is returned unchanged.
    sorted = sort(a(:));
    n = numel(sorted);
    if n == 1
        q = sorted;
        return;
    end

    pos = 1 + (n-1)*p/100;
    kLo = floor(pos);
    kHi = ceil(pos);
    q = sorted(kLo);
    if kHi ~= kLo
        q = q + (pos-kLo)*(sorted(kHi)-sorted(kLo));
    end
end
