% PINN_case02_heat_1D_traced_fixed.m
% Physics-informed neural network (PINN) demo for the 1D heat equation:
%   T_t = alpha*T_xx,   x in [0,1], t in [0,1]
%   T(0,t) = T(1,t) = 0,   T(x,0) = sin(pi*x)
% Exact solution: T(x,t) = sin(pi*x)*exp(-alpha*pi^2*t)
%
% The network T(x,t) is trained on three point sets:
%   - Nf random interior collocation points (PDE residual),
%   - Ni points on t = 0 (initial condition),
%   - Nb points on each of x = 0 and x = 1 (Dirichlet boundary conditions),
% and is then compared with the exact solution on a regular (x,t) grid.

clear; clc; close all;
rng(2);  % reproducible random sampling

%% Parameters
alpha     = 0.1;    % diffusivity in T_t = alpha*T_xx
numEpochs = 2500;
learnRate = 1e-3;   % Adam step size
Nf = 3000;          % interior collocation points
Ni = 160;           % initial-condition points
Nb = 160;           % points on each boundary

% Convert a (possibly GPU-resident) dlarray to a plain double array.
toDouble = @(v) double(gather(extractdata(v)));

%% Training points
x_f = rand(1,Nf);
t_f = rand(1,Nf);

x_i = linspace(0,1,Ni);
t_i = zeros(1,Ni);
T_i = exactSolution(x_i,t_i,alpha);

t_b  = linspace(0,1,Nb);
x_b0 = zeros(1,Nb);
x_b1 = ones(1,Nb);
T_b0 = zeros(1,Nb);
T_b1 = zeros(1,Nb);

% Each column is one sample [x; t]; 'CB' labels the dims as channel x batch.
XT_f    = dlarray(single([x_f;  t_f ]),'CB');
XT_i    = dlarray(single([x_i;  t_i ]),'CB');
T_i_dl  = dlarray(single(T_i),'CB');
XT_b0   = dlarray(single([x_b0; t_b ]),'CB');
XT_b1   = dlarray(single([x_b1; t_b ]),'CB');
T_b0_dl = dlarray(single(T_b0),'CB');
T_b1_dl = dlarray(single(T_b1),'CB');

%% Network and optimizer state
net = createMLP(2,1,40,4);   % (x,t) -> T, 4 hidden tanh layers of 40 units

trailingAvg   = [];   % Adam first-moment state
trailingAvgSq = [];   % Adam second-moment state
lossHist = zeros(numEpochs,1);

%% Training
% Keep an explicit axes handle so the loss curve is always drawn here,
% even if another figure gains focus while training runs.
figLoss = figure('Name','Training Loss');
axLoss  = axes(figLoss);
for epoch = 1:numEpochs
    [loss,gradients,lossPDE,lossIC,lossBC] = dlfeval(@modelLoss,net,XT_f,XT_i,T_i_dl,XT_b0,T_b0_dl,XT_b1,T_b1_dl,alpha);
    [net,trailingAvg,trailingAvgSq] = adamupdate(net,gradients,trailingAvg,trailingAvgSq,epoch,learnRate);

    lossHist(epoch) = toDouble(loss);
    if mod(epoch,100)==0 || epoch==1
        % The plot call replaces the axes contents, so grid/labels are reapplied.
        semilogy(axLoss,1:epoch,lossHist(1:epoch),'LineWidth',1.5);
        grid(axLoss,'on');
        xlabel(axLoss,'Epoch'); ylabel(axLoss,'Loss');
        title(axLoss,'PINN training: 1D heat equation');
        drawnow;
        fprintf('Epoch %5d | Loss %.3e | PDE %.3e | IC %.3e | BC %.3e\n', ...
            epoch,lossHist(epoch),toDouble(lossPDE),toDouble(lossIC),toDouble(lossBC));
    end
end

%% Evaluation on a regular grid
Nx = 160; Nt = 100;
x = linspace(0,1,Nx);
t = linspace(0,1,Nt);
[X,T] = meshgrid(x,t);   % Nt-by-Nx; each row is one time level
XT_test = dlarray(single([X(:)'; T(:)']),'CB');
% X(:) is column-major, so reshaping back to Nt-by-Nx restores the grid layout.
T_pred = reshape(toDouble(forward(net,XT_test)),Nt,Nx);
T_ex   = exactSolution(X,T,alpha);
err    = abs(T_pred-T_ex);
fprintf('Relative L2 error %.3e | Max abs error %.3e\n', ...
    norm(T_pred(:)-T_ex(:))/norm(T_ex(:)), max(err(:)));

figure('Name','PINN temperature');
contourf(X,T,T_pred,30,'LineColor','none'); colorbar; xlabel('x'); ylabel('t'); title('PINN solution');
figure('Name','Exact temperature');
contourf(X,T,T_ex,30,'LineColor','none'); colorbar; xlabel('x'); ylabel('t'); title('Exact solution');
figure('Name','Absolute error');
contourf(X,T,err,30,'LineColor','none'); colorbar; xlabel('x'); ylabel('t'); title('Absolute error');

%% Profiles at selected times: same colour per time, dashed = exact, solid = PINN
figure('Name','Line comparison'); hold on; grid on;
tSlices = [0,0.25,0.5,0.75,1.0];
sliceColors = lines(numel(tSlices));
hPred = gobjects(1,numel(tSlices));
for k = 1:numel(tSlices)
    [~,idx] = min(abs(t-tSlices(k)));   % nearest grid time level
    plot(x,T_ex(idx,:),'--','Color',sliceColors(k,:),'LineWidth',1.2);
    hPred(k) = plot(x,T_pred(idx,:),'-','Color',sliceColors(k,:),'LineWidth',1.2, ...
        'DisplayName',sprintf('t = %.2f',t(idx)));
end
xlabel('x'); ylabel('T'); title('Line comparison: dashed exact, solid PINN');
legend(hPred,'Location','best');

function [loss,gradients,lossPDE,lossIC,lossBC] = modelLoss(net,XT_f,XT_i,T_i,XT_b0,T_b0,XT_b1,T_b1,alpha,wIC)
%MODELLOSS PINN loss for T_t = alpha*T_xx and its gradient w.r.t. the learnables.
%   Must be called through dlfeval so automatic differentiation is traced.
%
%   Inputs (point sets are 'CB' dlarrays whose row 1 is x and row 2 is t):
%     net          - dlnetwork mapping [x; t] -> T
%     XT_f         - interior collocation points (PDE residual)
%     XT_i, T_i    - initial-condition points and their target values
%     XT_b0, T_b0  - points and targets on the first boundary set
%     XT_b1, T_b1  - points and targets on the second boundary set
%     alpha        - diffusivity in the PDE residual
%     wIC          - (optional) weight of the initial-condition term, default 10
%
%   Outputs:
%     loss      - lossPDE + wIC*lossIC + lossBC
%     gradients - d(loss)/d(net.Learnables), same layout as net.Learnables
%     lossPDE, lossIC, lossBC - the individual mean-squared-error terms
    if nargin < 10 || isempty(wIC)
        wIC = 10;
    end

    T_f = forward(net,XT_f);

    % Each output column depends only on its own input column, so the
    % gradient of the summed output yields per-point derivatives.
    % EnableHigherDerivatives keeps the derivative graph so it can be
    % differentiated again (for T_xx and for the final loss gradient).
    grad_T = dlgradient(sum(T_f,'all'),XT_f,'EnableHigherDerivatives',true);
    T_x = grad_T(1,:);
    T_t = grad_T(2,:);

    grad_Tx = dlgradient(sum(T_x,'all'),XT_f,'EnableHigherDerivatives',true);
    T_xx = grad_Tx(1,:);

    r = T_t - alpha*T_xx;
    lossPDE = mean(r.^2,'all');

    T_i_pred = forward(net,XT_i);
    lossIC = mean((T_i_pred-T_i).^2,'all');

    T_b0_pred = forward(net,XT_b0);
    T_b1_pred = forward(net,XT_b1);
    lossBC = mean((T_b0_pred-T_b0).^2,'all') + mean((T_b1_pred-T_b1).^2,'all');

    loss = lossPDE + wIC*lossIC + lossBC;
    gradients = dlgradient(loss,net.Learnables);
end

function T = exactSolution(x,t,alpha)
%EXACTSOLUTION Analytic temperature sin(pi*x)*exp(-alpha*pi^2*t).
%   Evaluated element-wise, so x and t may be equal-size arrays (e.g. from
%   meshgrid) or compatible for implicit expansion.
    decayFactor = exp(-(alpha*pi^2).*t);   % temporal decay of the single sine mode
    spatialMode = sin(pi.*x);
    T = spatialMode .* decayFactor;
end

function net = createMLP(numInputs,numOutputs,numNeurons,numHidden)
%CREATEMLP Build an initialized fully connected tanh network as a dlnetwork.
%   net = CREATEMLP(numInputs,numOutputs,numNeurons,numHidden) stacks
%     featureInputLayer(numInputs, no normalization)
%     -> numHidden x [fullyConnectedLayer(numNeurons), tanhLayer]
%     -> fullyConnectedLayer(numOutputs)
%   Layer names are 'input', 'fc<k>', 'tanh<k>' and 'output'.
%   The smooth tanh activation keeps the second derivatives used by a PDE
%   residual well defined.
    hiddenLayers = cell(2*numHidden,1);
    for k = 1:numHidden
        hiddenLayers{2*k-1} = fullyConnectedLayer(numNeurons,'Name',['fc' num2str(k)]);
        hiddenLayers{2*k}   = tanhLayer('Name',['tanh' num2str(k)]);
    end
    layers = [featureInputLayer(numInputs,'Normalization','none','Name','input')
        hiddenLayers{:}
        fullyConnectedLayer(numOutputs,'Name','output')];
    % dlnetwork accepts the Layer array directly; wrapping it in layerGraph
    % is no longer recommended.
    net = dlnetwork(layers);
end
