Saturday, May 16, 2026

Explanation of Temporal GAT

The sample code is as follows.


# ==============================================================
# PRODUCTION-AWARE TEMPORAL GAT
# Includes FULL lifecycle:
#
# Initial Training
# Risk Scores
# Monitoring
# Drift Detection
# Retraining
# Post-Retrain Predictions
#
# Plus Temporal Graph Learning
#
# pip install torch torch-geometric
# ==============================================================

import torch
import torch.nn.functional as F
from torch.nn import Linear, LayerNorm, Dropout, LSTM
from torch_geometric.nn import GATv2Conv
import copy
import random

torch.manual_seed(42)
random.seed(42)

# ==============================================================
# CONFIG
# ==============================================================

TIMESTEPS = 300
SEQ_LEN = 8
EPOCHS = 15

# ==============================================================
# KNOWLEDGE GRAPH
# ==============================================================

services = [
    "Frontend",
    "Auth",
    "Cart",
    "Order",
    "Payment",
    "Inventory",
    "Fraud",
    "Notification",
]

idx = {n:i for i,n in enumerate(services)}
NODES = len(services)
FEATS = 4

edges = [
    ("Frontend", "Auth"),
    ("Frontend", "Cart"),
    ("Frontend", "Order"),
    ("Cart", "Inventory"),
    ("Order", "Payment"),
    ("Order", "Inventory"),
    ("Order", "Fraud"),
    ("Order", "Notification"),
]

edge_index = torch.tensor(
    [[idx[s], idx[t]] for s, t in edges],
    dtype=torch.long
).t().contiguous()

# ==============================================================
# GENERATE TEMPORAL TELEMETRY
# ==============================================================

base = torch.tensor([
    [40, 50, 120, 0.01],   # Frontend
    [20, 30,  80, 0.00],   # Auth
    [50, 60, 100, 0.02],   # Cart
    [70, 75, 220, 0.05],   # Order
    [60, 70, 180, 0.03],   # Payment
    [45, 55, 140, 0.01],   # Inventory
    [30, 35, 110, 0.00],   # Fraud
    [25, 40,  90, 0.00],   # Notification
], dtype=torch.float)      # columns: CPU, Memory, Latency, Error rate

def generate_data(steps=TIMESTEPS, drift=False):

    xs, ys = [], []

    for t in range(steps):

        x = base + torch.randn(NODES, FEATS) * 3
        y = torch.zeros(NODES, dtype=torch.long)

        # Normal payment issue
        if random.random() < 0.15:
            x[idx["Payment"]]  += torch.tensor([20, 15, 180, 0.20])
            x[idx["Order"]]    += torch.tensor([8, 8, 60, 0.08])
            x[idx["Frontend"]] += torch.tensor([5, 5, 30, 0.03])

            y[idx["Payment"]]  = 1
            y[idx["Order"]]    = 1
            y[idx["Frontend"]] = 1

        # Drift scenario (worse environment)
        if drift and random.random() < 0.25:
            x[idx["Cart"]] += torch.tensor([10, 8, 70, 0.07])
            y[idx["Cart"]] = 1

        xs.append(x)
        ys.append(y)

    xs = torch.stack(xs)
    ys = torch.stack(ys)

    return xs, ys

all_x_raw, all_y = generate_data()

# ==============================================================
# FEATURE ENGINEERING + SCALING
# ==============================================================

mean = all_x_raw.mean((0,1))
std = all_x_raw.std((0,1)) + 1e-8

all_x = (all_x_raw - mean) / std

print("Scaled sample values\n", all_x[0])

# ==============================================================
# BUILD TEMPORAL WINDOWS
# ==============================================================

def make_windows(X, Y):

    seq_x, seq_y = [], []

    for t in range(SEQ_LEN, len(X)):
        seq_x.append(X[t-SEQ_LEN:t])   # [S,N,F]
        seq_y.append(Y[t])             # [N]

    return torch.stack(seq_x), torch.stack(seq_y)

X_all, Y_all = make_windows(all_x, all_y)

split = int(len(X_all)*0.8)

X_train = X_all[:split]
Y_train = Y_all[:split]

X_test = X_all[split:]
Y_test = Y_all[split:]

# ==============================================================
# MODEL
# ==============================================================

class TemporalGAT(torch.nn.Module):

    def __init__(self):
        super().__init__()

        # GAT Layer 1: 4 heads x 16 channels, concatenated -> 64-dim output
        self.gat1 = GATv2Conv(
            in_channels=4,
            out_channels=16,
            heads=4,
            concat=True,
            dropout=0.2
        )

        # GAT Layer 2: 2 heads x 16 channels, concatenated -> 32-dim output
        self.gat2 = GATv2Conv(
            in_channels=64,
            out_channels=16,
            heads=2,
            concat=True,
            dropout=0.2
        )

        self.norm1 = LayerNorm(64)
        self.norm2 = LayerNorm(32)

        self.res = Linear(64, 32)
        self.drop = Dropout(0.2)

        # Temporal layer
        self.lstm = LSTM(
            input_size=32,
            hidden_size=64,
            batch_first=True
        )

        # Prediction head
        self.head = Linear(64, 2)

    def graph_encode(self, x):

        h1 = self.gat1(x, edge_index)
        h1 = F.elu(h1)
        h1 = self.norm1(h1)
        h1 = self.drop(h1)

        h2 = self.gat2(h1, edge_index)
        h2 = h2 + self.res(h1)   # residual connection
        h2 = F.elu(h2)
        h2 = self.norm2(h2)

        return h2

    def forward(self, X):
        # X = [B,S,N,F]
        B, S, N, Fdim = X.shape

        batch_graphs = []

        for b in range(B):

            seq_emb = []

            for t in range(S):
                emb = self.graph_encode(X[b, t])   # [N,32]
                seq_emb.append(emb)

            seq_emb = torch.stack(seq_emb)
            batch_graphs.append(seq_emb)

        batch_graphs = torch.stack(batch_graphs)   # [B,S,N,32]

        outputs = []

        for node in range(N):
            node_seq = batch_graphs[:, :, node, :]   # [B,S,32]
            out, _ = self.lstm(node_seq)
            last = out[:, -1, :]                     # final temporal state
            logits = self.head(last)
            outputs.append(logits)

        outputs = torch.stack(outputs, dim=1)   # [B,N,2]

        return outputs

# ==============================================================
# TRAINING
# ==============================================================

def train_model(model, X, Y, epochs=EPOCHS):

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=0.003,
        weight_decay=1e-4
    )

    best_loss = 1e9
    best_state = None

    for epoch in range(1, epochs + 1):

        model.train()
        optimizer.zero_grad()

        logits = model(X)

        loss = F.cross_entropy(
            logits.reshape(-1, 2),
            Y.reshape(-1)
        )

        loss.backward()
        optimizer.step()

        # Keep the best-performing weights
        if loss.item() < best_loss:
            best_loss = loss.item()
            best_state = copy.deepcopy(model.state_dict())

        if epoch % 3 == 0:

            pred = logits.argmax(dim=2)
            acc = (pred == Y).float().mean().item()

            print(
                f"Epoch {epoch:03d} | "
                f"Loss {loss.item():.4f} | "
                f"Acc {acc:.3f}"
            )

    model.load_state_dict(best_state)

# ==============================================================
# MONITORING + RETRAINING
# ==============================================================

def monitor_and_retrain(model):

    print("\n================================================")
    print("MONITORING")
    print("================================================")

    new_raw, new_y = generate_data(steps=TIMESTEPS, drift=True)

    new_scaled = (new_raw - mean) / std

    drift_score = torch.mean(
        torch.abs(new_scaled - all_x)
    ).item()

    print(f"Feature Drift Score: {drift_score:.4f}")

    if drift_score > 0.05:

        print("Drift Detected -> Retraining Triggered")

        X_new, Y_new = make_windows(new_scaled, new_y)

        train_model(model, X_new, Y_new, epochs=8)

        return X_new, Y_new

    print("No Retraining Needed")

    return X_test, Y_test

# ==============================================================
# MAIN
# ==============================================================

model = TemporalGAT()

print("================================================")
print("INITIAL TRAINING")
print("================================================")

train_model(model, X_train, Y_train)

# ==============================================================
# RISK SCORES
# ==============================================================

print("\n================================================")
print("RISK SCORES")
print("================================================")

model.eval()

with torch.no_grad():

    sample = X_test[:1]

    logits = model(sample)

    probs = F.softmax(logits, dim=2)[0, :, 1]

    for i, name in enumerate(services):
        print(f"{name:15s} Risk Score = {probs[i]:.3f}")

# ==============================================================
# MONITORING
# ==============================================================

X_post, Y_post = monitor_and_retrain(model)

# ==============================================================
# POST RETRAIN PREDICTIONS
# ==============================================================

print("\n================================================")
print("POST-RETRAIN PREDICTIONS")
print("================================================")

model.eval()   # retraining leaves the model in train mode; disable dropout for inference

with torch.no_grad():

    sample = X_post[:1]

    logits = model(sample)

    probs = F.softmax(logits, dim=2)[0, :, 1]

    for i, name in enumerate(services):
        print(f"{name:15s} Risk Score = {probs[i]:.3f}")


Step-by-Step Flow of the Entire Temporal GAT Training Procedure

The easiest way to understand the training lifecycle is to think of it as a repeated cycle of:

Observe → Predict → Measure Error → Adjust Weights → Repeat

The model gradually learns how operational failures propagate through both:

  • graph structure

  • time evolution


1. Build the Knowledge Graph

The process begins by constructing the dependency graph.

edge_index = torch.tensor(...)

This defines:

Which nodes influence which other nodes

Example:

Frontend → Order
Order → Payment

At this stage, the graph only defines topology.

The model still knows nothing about operational behavior.
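
As a quick sanity check, here is a minimal, self-contained sketch (using just two of the edges above) of what edge_index actually contains:

import torch

idx = {"Frontend": 0, "Auth": 1, "Order": 3, "Payment": 4}
edges = [("Frontend", "Auth"), ("Order", "Payment")]

edge_index = torch.tensor(
    [[idx[s], idx[t]] for s, t in edges],
    dtype=torch.long
).t().contiguous()

print(edge_index)
# tensor([[0, 3],    <- source node indices
#         [1, 4]])   <- target node indices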


2. Generate Temporal Telemetry

Synthetic telemetry is generated:

generate_data()

Each timestamp contains features like:

  • CPU

  • Memory

  • Latency

  • Error rate

Shape:

[Timestamps, Nodes, Features]

Example:

[300, 8, 4]

Meaning:

300 timestamps
8 services
4 operational metrics
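
A quick shape check (reusing generate_data from the sample code above) confirms this layout:

all_x_raw, all_y = generate_data()

print(all_x_raw.shape)   # torch.Size([300, 8, 4]) -> [timestamps, nodes, features]
print(all_y.shape)       # torch.Size([300, 8])    -> one label per node per timestamp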

3. Inject Failure Propagation Patterns

The simulation intentionally injects cascading failures.

Example:

Payment worsens
↓
Order worsens
↓
Frontend worsens

This teaches the model that:

risk propagates through dependencies

rather than appearing randomly.


4. Feature Scaling

Raw metrics are normalized:

all_x = (all_x_raw - mean) / std

This is critical because:

Latency values may be in the hundreds
Error rates may be tiny decimals

Without scaling:

the numerically largest feature dominates learning

instead of the most meaningful one.
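
A small worked example (with made-up numbers, not the telemetry above) shows the effect of the same z-score formula:

import torch

# Two features on very different scales: latency (ms) and error rate
raw = torch.tensor([[120.0, 0.01],
                    [220.0, 0.05],
                    [180.0, 0.03]])

mean = raw.mean(dim=0)
std = raw.std(dim=0) + 1e-8

scaled = (raw - mean) / std
print(scaled)   # both columns now have mean ~0 and std ~1, so neither dominates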


5. Build Temporal Windows

The continuous timeline is converted into sequences.

make_windows()

Example:

timestamps 1–8 → predict timestamp 9
timestamps 2–9 → predict timestamp 10

Shape becomes:

[Batch, Sequence, Nodes, Features]

Example:

[233, 8, 8, 4]

Meaning:

Dimension   Meaning
233         training samples
8           time window (SEQ_LEN)
8           services
4           metrics
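
To see where 233 comes from: 300 timestamps yield 300 - 8 = 292 windows, and the 80% split keeps int(292 * 0.8) = 233 of them for training. A standalone version of the windowing logic:

import torch

SEQ_LEN = 8
X = torch.randn(300, 8, 4)           # [timestamps, nodes, features]
Y = torch.randint(0, 2, (300, 8))    # [timestamps, nodes]

seq_x = [X[t-SEQ_LEN:t] for t in range(SEQ_LEN, len(X))]
seq_y = [Y[t] for t in range(SEQ_LEN, len(X))]

X_all = torch.stack(seq_x)   # [292, 8, 8, 4]
Y_all = torch.stack(seq_y)   # [292, 8]

split = int(len(X_all) * 0.8)
print(X_all[:split].shape)   # torch.Size([233, 8, 8, 4])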

6. Initialize the Temporal GAT Model

The model architecture is created.

model = TemporalGAT()

At this point:

all neural weights are random

including:

  • graph attention weights

  • projection matrices

  • LSTM parameters

  • prediction head weights

The model has not learned anything yet.


7. Start Training Loop

Training begins:

for epoch in range(epochs):

Each epoch is one complete learning cycle.


8. Forward Pass Begins

This line triggers everything:

logits = model(X)

Internally:

model.__call__(X)
↓
forward(X)

Now the full Temporal GAT pipeline executes.


9. Forward Step — Process Temporal Batch

Inside forward():

B, S, N, Fdim = X.shape

Example:

[233, 8, 8, 4]

Meaning:

233 training sequences
8 timestamps each
8 services
4 features

10. Graph Encoding Per Timestamp

For every timestamp:

emb = self.graph_encode(X[b,t])

This processes:

ONE graph snapshot

through GAT.


11. GAT Layer 1 Executes

h1 = self.gat1(x, edge_index)

Now the model learns:

Which neighboring services matter most

using attention.

This stage performs:

  • linear projection

  • attention computation

  • softmax normalization

  • weighted aggregation

internally.
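
For intuition, here is a deliberately simplified, single-head sketch of the GATv2 attention idea for one target node. This is not the torch_geometric internals; GATv2Conv vectorizes this over all edges and heads:

import torch
import torch.nn.functional as F

W_src = torch.nn.Linear(4, 16, bias=False)   # projects neighbor features
W_dst = torch.nn.Linear(4, 16, bias=False)   # projects target features
a = torch.nn.Linear(16, 1, bias=False)       # attention scoring vector

x = torch.randn(8, 4)     # 8 nodes, 4 features each
target = 3                # e.g. "Order"
neighbors = [0, 4, 5]     # hypothetical neighbor indices

# GATv2 scores: the nonlinearity is applied BEFORE the attention vector
scores = torch.cat([
    a(F.leaky_relu(W_dst(x[target]) + W_src(x[j])))
    for j in neighbors
])

alpha = F.softmax(scores, dim=0)   # softmax normalization over neighbors

out = sum(alpha[k] * W_src(x[j])   # weighted aggregation
          for k, j in enumerate(neighbors))

print(alpha)   # which neighbors matter most for this node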


12. Activation + Stabilization

h1 = F.elu(h1)
h1 = self.norm1(h1)
h1 = self.drop(h1)

These steps:

Layer       Purpose
ELU         non-linearity
LayerNorm   stabilize embeddings
Dropout     prevent overfitting

13. GAT Layer 2 Executes

h2 = self.gat2(h1, edge_index)

Now deeper relationship propagation occurs.

Meaning:

multi-hop influence

can emerge.

Example:

Frontend influences Order
Order influences Payment

(messages flow along the edge direction, so two stacked GAT layers provide two hops of influence)

14. Residual Connection Applied

h2 = h2 + self.res(h1)

This prevents:

oversmoothing

where all node embeddings become too similar.


15. Graph Embeddings Produced

Final graph output:

[Nodes, 32]

Each node now has:

graph-aware embedding

representing:

  • local telemetry

  • dependency influence

  • propagated risk


16. Temporal Sequence Construction

For each node:

node_seq = batch_graphs[:,:,node,:]

This extracts:

node history across time

Shape:

[Batch, Sequence, Embedding]

17. LSTM Processes Temporal Evolution

out,_ = self.lstm(node_seq)

Now the model learns:

  • degradation trends

  • recurring instability

  • progressive failure buildup

  • temporal propagation

The GAT learned:

spatial relationships

The LSTM learns:

temporal behavior

18. Final Temporal State Selected

last = out[:,-1,:]

This captures:

summary of recent operational history
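
A shape-level sketch of steps 16-18, using random stand-ins for the graph embeddings:

import torch

lstm = torch.nn.LSTM(input_size=32, hidden_size=64, batch_first=True)

batch_graphs = torch.randn(233, 8, 8, 32)   # [B, S, N, 32] stand-in

node_seq = batch_graphs[:, :, 0, :]   # history of node 0: [233, 8, 32]
out, _ = lstm(node_seq)               # [233, 8, 64]
last = out[:, -1, :]                  # final temporal state: [233, 64]

print(last.shape)   # torch.Size([233, 64])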

19. Prediction Head Executes

logits = self.head(last)

The model now predicts:

Risky
vs
Not Risky

for each node.


20. Forward Pass Ends

Final output shape:

[Batch, Nodes, Classes]

Example:

[233, 8, 2]

21. Loss Calculation

Training compares predictions against truth:

loss = F.cross_entropy(...)

This computes:

How wrong the model currently is
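
The reshape matters here: cross_entropy expects [samples, classes], so every node prediction in every sequence becomes one sample:

import torch
import torch.nn.functional as F

logits = torch.randn(233, 8, 2)       # [Batch, Nodes, Classes]
Y = torch.randint(0, 2, (233, 8))     # [Batch, Nodes]

loss = F.cross_entropy(
    logits.reshape(-1, 2),   # [233*8, 2] = [1864, 2]
    Y.reshape(-1)            # [1864]
)

print(loss)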

22. Backpropagation Begins

loss.backward()

PyTorch now automatically computes gradients for:

  • GAT attention weights

  • projection matrices

  • LSTM weights

  • prediction head

This is the true learning stage.


23. Weight Update

optimizer.step()

Weights slightly adjust.

Meaning:

attention becomes more meaningful
predictions become less wrong

24. Entire Forward Pass Repeats

Training loops again:

forward
→ loss
→ gradients
→ update
→ repeat

This repeated correction is how the network gradually learns operational behavior.


25. Best Model Saved

best_state = copy.deepcopy(...)

The best-performing weights are preserved.

This matters because graph training can become unstable again after a temporary improvement.


26. Inference / Risk Prediction

After training:

model.eval()

Now the model performs prediction only.

No weights change.


27. Softmax Converts Logits to Risk Probabilities

probs = F.softmax(logits, dim=2)

Now predictions become:

0.93 risk probability
0.12 risk probability

instead of raw logits.
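
For a single node, the conversion looks like this:

import torch
import torch.nn.functional as F

logits = torch.tensor([-1.2, 1.4])   # [not-risky, risky] logits for one node
probs = F.softmax(logits, dim=0)

print(probs)   # tensor([0.0691, 0.9309]) -> roughly 0.93 risk probability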


28. Monitoring Phase Begins

New telemetry arrives.

The system checks:

drift_score = ...

to determine whether the environment changed significantly.
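
The sample's drift metric is deliberately simple: the mean absolute difference between newly scaled telemetry and the training telemetry. A standalone sketch (production systems often prefer statistical tests such as PSI or Kolmogorov-Smirnov):

import torch

def drift_score(new_scaled, train_scaled):
    # mean absolute difference across all timestamps, nodes, features
    return torch.mean(torch.abs(new_scaled - train_scaled)).item()

old = torch.randn(300, 8, 4)
new = old + 0.3 * torch.randn(300, 8, 4)   # simulated distribution shift

print(drift_score(new, old))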


29. Drift Detection

If operational behavior changes:

new traffic patterns
new failures
new dependencies

the model may become stale.


30. Retraining Triggered

If drift exceeds threshold:

train_model(...)

runs again using newer data.

This is critical because:

graph intelligence decays over time

if operational systems evolve.


Final Conceptual Flow

Telemetry
    ↓
Graph Attention
    ↓
Dependency-Aware Embeddings
    ↓
Temporal Learning
    ↓
Risk Prediction
    ↓
Loss Calculation
    ↓
Backpropagation
    ↓
Weight Updates
    ↓
Monitoring
    ↓
Retraining

Most Important Insight

The model is not simply learning:

“Which node is unhealthy?”

It is gradually learning:

“How operational degradation propagates
through connected systems over time.”

That is the real significance of Temporal Graph Learning.
