% PINN_case01_advection_1D_traced_fixed.m
% 一维线性对流方程 PINN
% u_t + c u_x = 0, x in [0,1], t in [0,1]
% 周期边界: u(0,t)=u(1,t)
% 精确解: u(x,t)=sin(2*pi*(x-c*t))


clear; clc; close all; rng(1);

% ---------------- Problem and training parameters ----------------
c = 1.0;            % advection speed
numEpochs = 2000;   % number of full-batch Adam iterations
learnRate = 1e-3;   % Adam learning rate

Nf = 2500;      % PDE collocation points
Ni = 160;       % initial-condition points
Nb = 160;       % boundary points per side (x = 0 and x = 1)

% ---------------- Training points ----------------
% Collocation points drawn uniformly in [0,1] x [0,1].
x_f = rand(1,Nf);
t_f = rand(1,Nf);

% Initial condition on t = 0; targets come from the exact solution.
x_i = linspace(0,1,Ni);
t_i = zeros(1,Ni);
u_i = exactSolution(x_i,t_i,c);

% Identical time samples on x = 0 and x = 1, paired for the periodic condition.
t_b = linspace(0,1,Nb);
x_b0 = zeros(1,Nb);
x_b1 = ones(1,Nb);

% Build the dlarrays here in the main script and pass them directly into
% dlfeval, so they are traced inputs (modelLoss differentiates w.r.t. XT_f).
% Row 1 holds x and row 2 holds t; modelLoss relies on that ordering.
XT_f  = dlarray(single([x_f;  t_f ]),'CB');
XT_i  = dlarray(single([x_i;  t_i ]),'CB');
U_i   = dlarray(single(u_i),'CB');
XT_b0 = dlarray(single([x_b0; t_b ]),'CB');
XT_b1 = dlarray(single([x_b1; t_b ]),'CB');

% ---------------- Network ----------------
% 2 inputs (x,t), 1 output u, 4 hidden tanh layers of 40 neurons each.
net = createMLP(2,1,40,4);

% ---------------- Training ----------------
trailingAvg = [];
trailingAvgSq = [];
lossHist = zeros(numEpochs,1);
pdeHist  = zeros(numEpochs,1);
icHist   = zeros(numEpochs,1);
bcHist   = zeros(numEpochs,1);

% Keep an explicit axes handle so the loss curve always goes into this
% figure, even if another figure becomes current while training runs.
lossFig = figure('Name','Training Loss');
lossAx  = axes(lossFig);
for epoch = 1:numEpochs
    [loss,gradients,lossPDE,lossIC,lossBC] = dlfeval(@modelLoss,net,XT_f,XT_i,U_i,XT_b0,XT_b1,c);
    [net,trailingAvg,trailingAvgSq] = adamupdate(net,gradients,trailingAvg,trailingAvgSq,epoch,learnRate);

    lossHist(epoch) = double(gather(extractdata(loss)));
    pdeHist(epoch)  = double(gather(extractdata(lossPDE)));
    icHist(epoch)   = double(gather(extractdata(lossIC)));
    bcHist(epoch)   = double(gather(extractdata(lossBC)));

    if mod(epoch,100)==0 || epoch==1
        if ~isgraphics(lossAx)
            % The loss figure was closed by the user: recreate it.
            lossFig = figure('Name','Training Loss');
            lossAx  = axes(lossFig);
        end
        semilogy(lossAx,1:epoch,lossHist(1:epoch),'LineWidth',1.5); grid(lossAx,'on');
        xlabel(lossAx,'Epoch'); ylabel(lossAx,'Loss'); title(lossAx,'PINN training: 1D advection'); drawnow;
        fprintf('Epoch %5d | Loss %.3e | PDE %.3e | IC %.3e | BC %.3e\n',epoch,lossHist(epoch),pdeHist(epoch),icHist(epoch),bcHist(epoch));
    end
end

% ---------------- Prediction vs exact solution ----------------
Nx = 160; Nt = 100;
x = linspace(0,1,Nx);
t = linspace(0,1,Nt);
[X,T] = meshgrid(x,t);                      % Nt x Nx grids
XT_test = dlarray(single([X(:)'; T(:)']),'CB');
U_pred = forward(net,XT_test);
U_pred = reshape(double(gather(extractdata(U_pred))),Nt,Nx);
U_ex = exactSolution(X,T,c);
err = abs(U_pred-U_ex);

% Scalar accuracy summary on the test grid.
relL2 = norm(U_pred(:)-U_ex(:))/norm(U_ex(:));
fprintf('Test grid | max abs error %.3e | relative L2 error %.3e\n',max(err(:)),relL2);

figure('Name','PINN solution');
contourf(X,T,U_pred,30,'LineColor','none'); colorbar;
xlabel('x'); ylabel('t'); title('PINN solution');

figure('Name','Exact solution');
contourf(X,T,U_ex,30,'LineColor','none'); colorbar;
xlabel('x'); ylabel('t'); title('Exact solution');

figure('Name','Absolute error');
contourf(X,T,err,30,'LineColor','none'); colorbar;
xlabel('x'); ylabel('t'); title('Absolute error');

% Line comparison: each time slice gets one color shared by the exact
% (dashed) and PINN (solid) curves, so the pairs are identifiable.
figure('Name','Line comparison');
hold on; grid on;
tSlices = [0,0.25,0.5,0.75,1.0];
sliceColors = lines(numel(tSlices));
legendEntries = cell(1,2*numel(tSlices));
for k = 1:numel(tSlices)
    [~,idx] = min(abs(t-tSlices(k)));       % nearest grid time to the requested slice
    plot(x,U_ex(idx,:),'--','Color',sliceColors(k,:),'LineWidth',1.2);
    plot(x,U_pred(idx,:),'-','Color',sliceColors(k,:),'LineWidth',1.2);
    legendEntries{2*k-1} = sprintf('exact t=%.2f',t(idx));
    legendEntries{2*k}   = sprintf('PINN t=%.2f',t(idx));
end
legend(legendEntries,'Location','bestoutside');
xlabel('x'); ylabel('u'); title('Line comparison: dashed exact, solid PINN');

%% 局部函数
function [loss,gradients,lossPDE,lossIC,lossBC] = modelLoss(net,XT_f,XT_i,U_i,XT_b0,XT_b1,c,wIC,wBC)
%MODELLOSS PINN loss for u_t + c*u_x = 0 and its gradient w.r.t. net.Learnables.
%   Intended to be called through dlfeval so the dlarray inputs are traced
%   (dlgradient below differentiates with respect to XT_f).
%
%   Inputs
%     net   - dlnetwork evaluated with forward
%     XT_f  - PDE collocation points; row 1 is treated as x, row 2 as t
%     XT_i  - initial-condition points
%     U_i   - target values of u at XT_i
%     XT_b0 - boundary points on one side (x = 0 in the calling script)
%     XT_b1 - boundary points on the other side, paired column-by-column
%             with XT_b0 (x = 1 in the calling script)
%     c     - advection speed
%     wIC   - (optional) weight of the initial-condition loss, default 10
%     wBC   - (optional) weight of the periodic-boundary loss, default 1
%
%   Outputs
%     loss      - lossPDE + wIC*lossIC + wBC*lossBC
%     gradients - dlgradient of loss w.r.t. net.Learnables
%     lossPDE   - mean squared PDE residual u_t + c*u_x at XT_f
%     lossIC    - mean squared error of the prediction against U_i
%     lossBC    - mean squared mismatch between predictions at XT_b0 and XT_b1

    % Defaults reproduce the original fixed weighting (10 for IC, 1 for BC).
    if nargin < 8 || isempty(wIC), wIC = 10; end
    if nargin < 9 || isempty(wBC), wBC = 1;  end

    % ---- PDE residual ----
    u_f = forward(net,XT_f);

    % dlgradient needs a scalar. Summing over the batch still yields
    % per-point derivatives, assuming each output column depends only on its
    % own input column (no cross-batch layers). Higher derivatives must be
    % enabled because the loss is then differentiated again w.r.t. the weights.
    grad_u = dlgradient(sum(u_f,'all'),XT_f,'EnableHigherDerivatives',true);
    u_x = grad_u(1,:);
    u_t = grad_u(2,:);

    r = u_t + c*u_x;
    lossPDE = mean(r.^2,'all');

    % ---- Initial condition ----
    u_i_pred = forward(net,XT_i);
    lossIC = mean((u_i_pred-U_i).^2,'all');

    % ---- Periodic boundary: u at XT_b0 should equal u at XT_b1 ----
    u_b0 = forward(net,XT_b0);
    u_b1 = forward(net,XT_b1);
    lossBC = mean((u_b0-u_b1).^2,'all');

    loss = lossPDE + wIC*lossIC + wBC*lossBC;
    gradients = dlgradient(loss,net.Learnables);
end

function u = exactSolution(x,t,c)
%EXACTSOLUTION Reference solution u(x,t) = sin(2*pi*(x - c*t)).
%   The initial profile sin(2*pi*x) is shifted by c*t; x and t may be
%   scalars or equally sized arrays (c*t uses matrix multiply, so c is
%   expected to be a scalar).
    shifted = x - c*t;          % argument that stays constant along x - c*t = const
    u = sin(2*pi*shifted);
end

function net = createMLP(numInputs,numOutputs,numNeurons,numHidden)
%CREATEMLP Fully connected tanh network wrapped as a dlnetwork.
%   Layer order: 'input' (featureInputLayer, no normalization), then
%   numHidden pairs of fullyConnectedLayer 'fcK' (numNeurons units) and
%   tanhLayer 'tanhK', then a final fullyConnectedLayer 'output' with
%   numOutputs units.
    hiddenStack = cell(2*numHidden,1);
    for layerIdx = 1:numHidden
        suffix = num2str(layerIdx);
        hiddenStack{2*layerIdx-1} = fullyConnectedLayer(numNeurons,'Name',['fc' suffix]);
        hiddenStack{2*layerIdx}   = tanhLayer('Name',['tanh' suffix]);
    end

    inputLayer  = featureInputLayer(numInputs,'Normalization','none','Name','input');
    outputLayer = fullyConnectedLayer(numOutputs,'Name','output');

    % vertcat keeps the hidden layers in one column; with numHidden = 0 it
    % yields [] and the network reduces to input -> output.
    layers = [inputLayer; vertcat(hiddenStack{:}); outputLayer];
    net = dlnetwork(layerGraph(layers));
end
