% PINN_case04_NS_TaylorGreen_2D_traced_fixed.m
% 二维不可压 Navier-Stokes 方程 PINN 演示：Taylor-Green vortex
% 输入: (x,y,t), 输出: (u,v,p)
% 周期区域 [0,2*pi] x [0,2*pi], t in [0,tMax]
% 精确解:
% u = -cos(x)*sin(y)*exp(-2*nu*t)
% v =  sin(x)*cos(y)*exp(-2*nu*t)
% p = -0.25*(cos(2x)+cos(2y))*exp(-4*nu*t)
%
% 说明：这个 case 计算量明显大于一维问题，课堂演示可先减小 numEpochs 和 Nf。

clear; clc; close all;
rng(4);   % fixed seed so the sampled points are reproducible

% ---- Physical / training parameters ----
nu        = 0.01;   % viscosity coefficient of the Laplacian term
tMax      = 0.5;    % final time of the space-time domain
numEpochs = 2500;   % number of optimizer iterations
learnRate = 8e-4;   % Adam learning rate
Nf        = 4000;   % number of PDE collocation points
Ni        = 1200;   % number of initial-condition points

% ---- PDE collocation points, uniform in [0,2*pi] x [0,2*pi] x [0,tMax] ----
% The draw order (x, then y, then t) is significant: it fixes which values
% the rng(4) stream assigns to each coordinate.
x_f = 2*pi*rand(1,Nf);
y_f = 2*pi*rand(1,Nf);
t_f = tMax*rand(1,Nf);

% ---- Initial-condition points at t = 0, labelled with the exact solution ----
x_i = 2*pi*rand(1,Ni);
y_i = 2*pi*rand(1,Ni);
t_i = zeros(size(x_i));
[u_i,v_i,p_i] = exactSolution(x_i,y_i,t_i,nu);

% Pack into single-precision 'CB' dlarrays:
% rows are channels (x,y,t or u,v,p), columns are samples.
XYT_f = dlarray(single(vertcat(x_f,y_f,t_f)),'CB');
XYT_i = dlarray(single(vertcat(x_i,y_i,t_i)),'CB');
UVP_i = dlarray(single(vertcat(u_i,v_i,p_i)),'CB');

% Network: 3 inputs (x,y,t) -> 5 hidden tanh layers of width 60 -> 3 outputs (u,v,p)
net = createMLP(3,3,60,5);

% Adam moment estimates start empty; adamupdate initializes them on first use.
trailingAvg   = [];
trailingAvgSq = [];
lossHist = zeros(numEpochs,1);

figure('Name','Training Loss');
for epoch = 1:numEpochs
    % Loss and parameter gradients are evaluated under automatic differentiation.
    [loss,gradients,lossPDE,lossIC] = dlfeval(@modelLoss,net,XYT_f,XYT_i,UVP_i,nu);
    [net,trailingAvg,trailingAvgSq] = adamupdate(net,gradients, ...
        trailingAvg,trailingAvgSq,epoch,learnRate);

    lossHist(epoch) = double(gather(extractdata(loss)));

    % Only the first epoch and every 100th epoch refresh the plot and log.
    if ~(epoch == 1 || mod(epoch,100) == 0)
        continue;
    end

    semilogy(1:epoch,lossHist(1:epoch),'LineWidth',1.5);
    grid on;
    xlabel('Epoch');
    ylabel('Loss');
    title('PINN training: Taylor-Green vortex');
    drawnow;
    fprintf('Epoch %5d | Loss %.3e | PDE %.3e | IC %.3e\n', epoch, lossHist(epoch), ...
        double(gather(extractdata(lossPDE))), double(gather(extractdata(lossIC))));
end

% ---- Compare the PINN against the exact solution on the t = tMax slice ----
Nx = 80; Ny = 80;
x = linspace(0,2*pi,Nx);
y = linspace(0,2*pi,Ny);
[X,Y] = meshgrid(x,y);          % X, Y are Ny-by-Nx
T = tMax*ones(size(X));
XYT_test = dlarray(single([X(:)'; Y(:)'; T(:)']),'CB');
UVP_pred = forward(net,XYT_test);
UVP_pred = double(gather(extractdata(UVP_pred)));
% X(:) is column-major, so reshape(...,Ny,Nx) restores the meshgrid layout.
U_pred = reshape(UVP_pred(1,:),Ny,Nx);
V_pred = reshape(UVP_pred(2,:),Ny,Nx);
P_pred = reshape(UVP_pred(3,:),Ny,Nx);

[U_ex,V_ex,P_ex] = exactSolution(X,Y,T,nu);

% Quantitative summary: relative L2 errors on this slice.
% The momentum residual only involves grad(p), and the IC term pins p only
% at t = 0, so p at t = tMax is compared after removing its mean.
relL2 = @(a,b) norm(a(:)-b(:))/norm(b(:));
errU = relL2(U_pred,U_ex);
errV = relL2(V_pred,V_ex);
errP = relL2(P_pred-mean(P_pred(:)),P_ex-mean(P_ex(:)));
fprintf('t = %.3f | rel. L2 error  u: %.3e  v: %.3e  p (mean-removed): %.3e\n', ...
    tMax,errU,errV,errP);

figure('Name','u PINN'); contourf(X,Y,U_pred,30,'LineColor','none'); colorbar; axis equal tight; title('u: PINN'); xlabel('x'); ylabel('y');
figure('Name','u exact'); contourf(X,Y,U_ex,30,'LineColor','none'); colorbar; axis equal tight; title('u: exact'); xlabel('x'); ylabel('y');
figure('Name','u error'); contourf(X,Y,abs(U_pred-U_ex),30,'LineColor','none'); colorbar; axis equal tight; title('|u_{PINN}-u_{exact}|'); xlabel('x'); ylabel('y');

figure('Name','v PINN'); contourf(X,Y,V_pred,30,'LineColor','none'); colorbar; axis equal tight; title('v: PINN'); xlabel('x'); ylabel('y');
figure('Name','p PINN'); contourf(X,Y,P_pred,30,'LineColor','none'); colorbar; axis equal tight; title('p: PINN'); xlabel('x'); ylabel('y');

function [loss,gradients,lossPDE,lossIC] = modelLoss(net,XYT_f,XYT_i,UVP_i,nu,wIC)
% modelLoss  PINN loss for 2-D incompressible Navier-Stokes and its gradients.
%   [loss,gradients,lossPDE,lossIC] = modelLoss(net,XYT_f,XYT_i,UVP_i,nu)
%   [loss,gradients,lossPDE,lossIC] = modelLoss(...,wIC)
%   Must be evaluated through dlfeval so that XYT_f is traced and its
%   derivatives can be taken with dlgradient.
%
%   Inputs
%     net    - dlnetwork whose 3 output channels are read as (u,v,p)
%     XYT_f  - collocation points, 'CB' dlarray with channel rows (x,y,t)
%     XYT_i  - initial-condition points, same layout as XYT_f
%     UVP_i  - target (u,v,p) values at XYT_i, 'CB' dlarray
%     nu     - coefficient of the Laplacian in both momentum equations
%     wIC    - (optional) weight of the initial-condition term; default 10
%
%   Outputs
%     loss      - lossPDE + wIC*lossIC
%     gradients - gradient of loss with respect to net.Learnables
%     lossPDE   - sum of mean squared continuity, x- and y-momentum residuals
%     lossIC    - mean squared (u,v,p) mismatch at XYT_i
%
%   NOTE(review): no boundary term is included, so periodicity of the
%   domain is not enforced by this loss.
    if nargin < 6 || isempty(wIC)
        wIC = 10;
    end

    UVP = forward(net,XYT_f);
    u = UVP(1,:);
    v = UVP(2,:);
    p = UVP(3,:);

    % First derivatives; each gradient has rows (d/dx, d/dy, d/dt).
    grad_u = pointwiseGrad(u,XYT_f);
    u_x = grad_u(1,:); u_y = grad_u(2,:); u_t = grad_u(3,:);

    grad_v = pointwiseGrad(v,XYT_f);
    v_x = grad_v(1,:); v_y = grad_v(2,:); v_t = grad_v(3,:);

    grad_p = pointwiseGrad(p,XYT_f);
    p_x = grad_p(1,:); p_y = grad_p(2,:);

    % Second derivatives for the Laplacian terms.
    grad_ux = pointwiseGrad(u_x,XYT_f);
    grad_uy = pointwiseGrad(u_y,XYT_f);
    u_xx = grad_ux(1,:); u_yy = grad_uy(2,:);

    grad_vx = pointwiseGrad(v_x,XYT_f);
    grad_vy = pointwiseGrad(v_y,XYT_f);
    v_xx = grad_vx(1,:); v_yy = grad_vy(2,:);

    % Residuals: continuity and the two momentum equations.
    r_cont = u_x + v_y;
    r_momx = u_t + u.*u_x + v.*u_y + p_x - nu*(u_xx+u_yy);
    r_momy = v_t + u.*v_x + v.*v_y + p_y - nu*(v_xx+v_yy);

    lossPDE = mean(r_cont.^2,'all') + mean(r_momx.^2,'all') + mean(r_momy.^2,'all');

    UVP_i_pred = forward(net,XYT_i);
    lossIC = mean((UVP_i_pred-UVP_i).^2,'all');

    loss = lossPDE + wIC*lossIC;
    gradients = dlgradient(loss,net.Learnables);
end

function g = pointwiseGrad(f,X)
% pointwiseGrad  Gradient of sum(f) with respect to X, kept differentiable.
%   EnableHigherDerivatives lets the result be differentiated again, both
%   for second derivatives and for the final gradient w.r.t. learnables.
%   Summing gives per-sample derivatives only if each column of f depends
%   on its own column of X alone. That assumption holds for the MLP built
%   by createMLP; confirm it for any other network.
    g = dlgradient(sum(f,'all'),X,'EnableHigherDerivatives',true);
end

function [u,v,p] = exactSolution(x,y,t,nu)
% exactSolution  Analytic 2-D Taylor-Green vortex, evaluated element-wise.
%   [u,v,p] = exactSolution(x,y,t,nu) returns
%     u = -cos(x) sin(y) exp(-2 nu t)
%     v =  sin(x) cos(y) exp(-2 nu t)
%     p = -1/4 (cos 2x + cos 2y) exp(-4 nu t)
%   x, y and t are combined with element-wise operators, so they must have
%   the same size (or sizes compatible for implicit expansion).
    velDecay = exp(-2*nu*t);     % decay factor shared by both velocity components
    u = -cos(x).*sin(y).*velDecay;
    v =  sin(x).*cos(y).*velDecay;

    presDecay = exp(-4*nu*t);    % pressure decays at twice the velocity rate
    p = -0.25*(cos(2*x) + cos(2*y)).*presDecay;
end

function net = createMLP(numInputs,numOutputs,numNeurons,numHidden)
% createMLP  Fully connected tanh network returned as an initialized dlnetwork.
%   net = createMLP(numInputs,numOutputs,numNeurons,numHidden) stacks
%     featureInputLayer(numInputs)               (no input normalization)
%     numHidden x [fullyConnectedLayer(numNeurons) -> tanhLayer]
%     fullyConnectedLayer(numOutputs)            (linear output)
%   Layers are named 'input', 'fc1', 'tanh1', ..., 'output'.
%   numHidden = 0 gives a single linear map from input to output.
    layers = featureInputLayer(numInputs,'Normalization','none','Name','input');
    for k = 1:numHidden
        layers = [layers
            fullyConnectedLayer(numNeurons,'Name',sprintf('fc%d',k))
            tanhLayer('Name',sprintf('tanh%d',k))]; %#ok<AGROW>
    end
    layers = [layers
        fullyConnectedLayer(numOutputs,'Name','output')];
    % dlnetwork accepts a Layer array directly. Wrapping a purely sequential
    % stack in layerGraph adds nothing and is not recommended as of R2024a.
    net = dlnetwork(layers);
end
