% PINN_case03_burgers_1D_traced_fixed.m
% Physics-informed neural network (PINN) demo for the 1-D viscous Burgers equation
%     u_t + u*u_x = nu*u_xx,   x in [0,1], t in [0,tMax]
% The test case is a traveling wave with a closed-form solution
%     u(x,t) = c0 - A*tanh(A*(x - c0*t - x0)/(2*nu)),
% which supplies the initial/boundary data and the reference for the error plots.
%
% Requires Deep Learning Toolbox (dlarray, dlnetwork, dlfeval, adamupdate).

clear; clc; close all;
rng(3);                 % reproducible collocation points and weight initialization

% ---- Problem parameters -------------------------------------------------
nu = 0.02;              % viscosity
c0 = 1.0;               % wave speed (u = c0 at the center of the front)
A  = 0.5;               % amplitude: u varies between c0 + A and c0 - A
x0 = 0.5;               % front position at t = 0
tMax = 0.4;             % final time

% ---- Training settings --------------------------------------------------
numEpochs = 3000;
learnRate = 1e-3;
Nf = 3500;              % interior collocation points (PDE residual)
Ni = 180;               % initial-condition samples
Nb = 180;               % samples on each spatial boundary

% Interior collocation points, uniform in [0,1] x [0,tMax]
x_f = rand(1,Nf);
t_f = tMax*rand(1,Nf);

% Initial-condition samples at t = 0
x_i = linspace(0,1,Ni);
t_i = zeros(1,Ni);
u_i = exactSolution(x_i,t_i,nu,c0,A,x0);

% Dirichlet boundary samples at x = 0 and x = 1, taken from the exact solution
t_b = linspace(0,tMax,Nb);
x_b0 = zeros(1,Nb);
x_b1 = ones(1,Nb);
u_b0 = exactSolution(x_b0,t_b,nu,c0,A,x0);
u_b1 = exactSolution(x_b1,t_b,nu,c0,A,x0);

% Pack as 'CB' dlarrays: channel dimension = [x; t] (or u), batch dimension = samples
XT_f  = dlarray(single([x_f;  t_f ]),'CB');
XT_i  = dlarray(single([x_i;  t_i ]),'CB');
U_i   = dlarray(single(u_i),'CB');
XT_b0 = dlarray(single([x_b0; t_b ]),'CB');
XT_b1 = dlarray(single([x_b1; t_b ]),'CB');
U_b0  = dlarray(single(u_b0),'CB');
U_b1  = dlarray(single(u_b1),'CB');

% 2 inputs (x,t) -> 1 output (u), 5 hidden tanh layers of 50 neurons
net = createMLP(2,1,50,5);

% ---- Training loop (Adam) -----------------------------------------------
trailingAvg = [];
trailingAvgSq = [];
lossHist = zeros(numEpochs,1);
toDouble = @(v) double(gather(extractdata(v)));   % dlarray (CPU or GPU) -> double

figure('Name','Training Loss');
for epoch = 1:numEpochs
    [loss,gradients,lossPDE,lossIC,lossBC] = dlfeval(@modelLoss,net,XT_f,XT_i,U_i,XT_b0,U_b0,XT_b1,U_b1,nu);
    [net,trailingAvg,trailingAvgSq] = adamupdate(net,gradients,trailingAvg,trailingAvgSq,epoch,learnRate);

    lossHist(epoch) = toDouble(loss);
    if mod(epoch,100)==0 || epoch==1
        semilogy(1:epoch,lossHist(1:epoch),'LineWidth',1.5); grid on;
        xlabel('Epoch'); ylabel('Loss'); title('PINN training: Burgers equation'); drawnow;
        fprintf('Epoch %5d | Loss %.3e | PDE %.3e | IC %.3e | BC %.3e\n',epoch,lossHist(epoch), ...
            toDouble(lossPDE),toDouble(lossIC),toDouble(lossBC));
    end
end

% ---- Evaluation on a regular (x,t) grid ---------------------------------
Nx = 180; Nt = 100;
x = linspace(0,1,Nx);
t = linspace(0,tMax,Nt);
[X,T] = meshgrid(x,t);                 % Nt-by-Nx: rows follow t, columns follow x
XT_test = dlarray(single([X(:)'; T(:)']),'CB');
U_pred = forward(net,XT_test);
% Undo the column-major X(:) flattening to get back an Nt-by-Nx grid
U_pred = reshape(toDouble(U_pred),Nt,Nx);
U_ex = exactSolution(X,T,nu,c0,A,x0);
err = abs(U_pred-U_ex);
fprintf('Max abs error %.3e | Relative L2 error %.3e\n', ...
    max(err(:)),norm(U_pred(:)-U_ex(:))/norm(U_ex(:)));

figure('Name','PINN Burgers');
contourf(X,T,U_pred,30,'LineColor','none'); colorbar; xlabel('x'); ylabel('t'); title('PINN solution');
figure('Name','Exact Burgers');
contourf(X,T,U_ex,30,'LineColor','none'); colorbar; xlabel('x'); ylabel('t'); title('Exact solution');
figure('Name','Absolute error');
contourf(X,T,err,30,'LineColor','none'); colorbar; xlabel('x'); ylabel('t'); title('Absolute error');

% ---- Profiles at selected times -----------------------------------------
% Each exact/PINN pair shares one color. Without an explicit color, every plot
% call advances the color order and the pairs cannot be matched visually.
tSlices = [0,0.1,0.2,0.3,0.4];
sliceColors = lines(numel(tSlices));
hExact = gobjects(1,numel(tSlices));
sliceLabels = cell(1,numel(tSlices));

figure('Name','Line comparison'); hold on; grid on;
for k = 1:numel(tSlices)
    [~,idx] = min(abs(t-tSlices(k)));  % nearest grid time to the requested slice
    hExact(k) = plot(x,U_ex(idx,:),'--','Color',sliceColors(k,:),'LineWidth',1.2);
    plot(x,U_pred(idx,:),'-','Color',sliceColors(k,:),'LineWidth',1.2);
    sliceLabels{k} = sprintf('t = %.2f',t(idx));
end
legend(hExact,sliceLabels,'Location','best');
xlabel('x'); ylabel('u'); title('Line comparison: dashed exact, solid PINN');

function [loss,gradients,lossPDE,lossIC,lossBC] = modelLoss(net,XT_f,XT_i,U_i,XT_b0,U_b0,XT_b1,U_b1,nu,icWeight,bcWeight)
    % MODELLOSS  PINN loss and parameter gradients for the 1-D Burgers equation.
    %   Must be evaluated through dlfeval so that dlgradient can trace the network.
    %
    %   Inputs
    %     net          dlnetwork mapping [x; t] -> u
    %     XT_f         'CB' dlarray of interior collocation points, rows [x; t]
    %     XT_i, U_i    initial-condition points and target values
    %     XT_b0, U_b0  boundary points at x = 0 and target values
    %     XT_b1, U_b1  boundary points at x = 1 and target values
    %     nu           viscosity
    %     icWeight     (optional) weight on lossIC, default 10
    %     bcWeight     (optional) weight on lossBC, default 1
    %
    %   Outputs
    %     loss         lossPDE + icWeight*lossIC + bcWeight*lossBC
    %     gradients    d(loss)/d(net.Learnables), as returned by dlgradient
    %     lossPDE      mean squared residual of u_t + u.*u_x - nu*u_xx
    %     lossIC       mean squared error on the initial-condition samples
    %     lossBC       sum of the mean squared errors at x = 0 and x = 1
    if nargin < 10 || isempty(icWeight)
        icWeight = 10;
    end
    if nargin < 11 || isempty(bcWeight)
        bcWeight = 1;
    end

    % ---- PDE residual on the collocation points ----
    u_f = forward(net,XT_f);

    % Row 1 of the input is x and row 2 is t. Summing over the batch before
    % differentiating gives per-sample derivatives. This assumes no
    % cross-sample coupling in net, which holds for the fc/tanh MLP from
    % createMLP. EnableHigherDerivatives keeps the backward pass traced, so
    % u_xx and the final gradient w.r.t. the learnables can differentiate
    % through these derivatives.
    grad_u = dlgradient(sum(u_f,'all'),XT_f,'EnableHigherDerivatives',true);
    u_x = grad_u(1,:);
    u_t = grad_u(2,:);

    grad_ux = dlgradient(sum(u_x,'all'),XT_f,'EnableHigherDerivatives',true);
    u_xx = grad_ux(1,:);

    r = u_t + u_f.*u_x - nu*u_xx;
    lossPDE = mean(r.^2,'all');

    % ---- Initial condition ----
    u_i_pred = forward(net,XT_i);
    lossIC = mean((u_i_pred-U_i).^2,'all');

    % ---- Dirichlet boundaries at x = 0 and x = 1 ----
    u_b0_pred = forward(net,XT_b0);
    u_b1_pred = forward(net,XT_b1);
    lossBC = mean((u_b0_pred-U_b0).^2,'all') + mean((u_b1_pred-U_b1).^2,'all');

    loss = lossPDE + icWeight*lossIC + bcWeight*lossBC;
    gradients = dlgradient(loss,net.Learnables);
end

function u = exactSolution(x,t,nu,c0,A,x0)
    % EXACTSOLUTION  Closed-form traveling-wave solution of viscous Burgers:
    %     u(x,t) = c0 - A*tanh(A*(x - c0*t - x0)/(2*nu))
    %   Evaluated elementwise over x and t, which must be the same size or
    %   compatible for implicit expansion. nu, c0, A and x0 are expected to be
    %   scalars, because the products below use '*'.
    %   u equals c0 on the front x = x0 + c0*t. It tends to c0 + A on the
    %   left of the front and c0 - A on the right.
    xi = x - c0*t - x0;                 % signed distance from the moving front
    u = c0 - A*tanh(A*xi/(2*nu));
end

function net = createMLP(numInputs,numOutputs,numNeurons,numHidden)
    % CREATEMLP  Build a fully connected tanh network as a dlnetwork.
    %   Layer chain:
    %     input -> [fcK -> tanhK] x numHidden -> output (linear)
    %   numInputs  : number of input features ('input' layer, no normalization)
    %   numOutputs : width of the final linear 'output' layer
    %   numNeurons : width of every hidden fully connected layer
    %   numHidden  : number of hidden fc+tanh pairs (0 gives input -> output)
    layers = featureInputLayer(numInputs,'Normalization','none','Name','input');
    for idx = 1:numHidden
        suffix = sprintf('%d',idx);
        layers = [layers; ...
                  fullyConnectedLayer(numNeurons,'Name',['fc' suffix]); ...
                  tanhLayer('Name',['tanh' suffix])]; %#ok<AGROW>
    end
    layers = [layers; fullyConnectedLayer(numOutputs,'Name','output')];
    net = dlnetwork(layerGraph(layers));
end
