The full sample code is below.
# ==============================================================
# PRODUCTION-AWARE TEMPORAL GAT
# Includes FULL lifecycle:
#
# Initial Training
# Risk Scores
# Monitoring
# Drift Detection
# Retraining
# Post-Retrain Predictions
#
# Plus Temporal Graph Learning
#
# pip install torch torch-geometric
# ==============================================================
import torch
import torch.nn.functional as F
from torch.nn import Linear, LayerNorm, Dropout, LSTM
from torch_geometric.nn import GATv2Conv
import copy
import random
torch.manual_seed(42)
random.seed(42)
# ==============================================================
# CONFIG
# ==============================================================
TIMESTEPS = 300
SEQ_LEN = 8
EPOCHS = 15
# ==============================================================
# KNOWLEDGE GRAPH
# ==============================================================
services = [
"Frontend",
"Auth",
"Cart",
"Order",
"Payment",
"Inventory",
"Fraud",
"Notification"
]
idx = {n:i for i,n in enumerate(services)}
NODES = len(services)
FEATS = 4
edges = [
("Frontend","Auth"),
("Frontend","Cart"),
("Frontend","Order"),
("Cart","Inventory"),
("Order","Payment"),
("Order","Inventory"),
("Order","Fraud"),
("Order","Notification")
]
edge_index = torch.tensor(
[[idx[s], idx[t]] for s,t in edges],
dtype=torch.long
).t().contiguous()
# ==============================================================
# GENERATE TEMPORAL TELEMETRY
# ==============================================================
base = torch.tensor([
[40,50,120,0.01],
[20,30, 80,0.00],
[50,60,100,0.02],
[70,75,220,0.05],
[60,70,180,0.03],
[45,55,140,0.01],
[30,35,110,0.00],
[25,40, 90,0.00],
], dtype=torch.float)
def generate_data(steps=TIMESTEPS, drift=False):
xs, ys = [], []
for t in range(steps):
x = base + torch.randn(NODES, FEATS) * 3
y = torch.zeros(NODES, dtype=torch.long)
# Normal payment issue
if random.random() < 0.15:
x[idx["Payment"]] += torch.tensor([20,15,180,0.20])
x[idx["Order"]] += torch.tensor([8,8,60,0.08])
x[idx["Frontend"]] += torch.tensor([5,5,30,0.03])
y[idx["Payment"]] = 1
y[idx["Order"]] = 1
y[idx["Frontend"]] = 1
# Drift scenario (worse environment)
if drift and random.random() < 0.25:
x[idx["Cart"]] += torch.tensor([10,8,70,0.07])
y[idx["Cart"]] = 1
xs.append(x)
ys.append(y)
xs = torch.stack(xs)
ys = torch.stack(ys)
return xs, ys
all_x_raw, all_y = generate_data()
# ==============================================================
# FEATURE ENGINEERING + SCALING
# ==============================================================
mean = all_x_raw.mean((0,1))
std = all_x_raw.std((0,1)) + 1e-8
all_x = (all_x_raw - mean) / std
print("Scaled sample values\n", all_x[0])
# ==============================================================
# BUILD TEMPORAL WINDOWS
# ==============================================================
def make_windows(X, Y):
seq_x, seq_y = [], []
for t in range(SEQ_LEN, len(X)):
seq_x.append(X[t-SEQ_LEN:t]) # [S,N,F]
seq_y.append(Y[t]) # [N]
return torch.stack(seq_x), torch.stack(seq_y)
X_all, Y_all = make_windows(all_x, all_y)
split = int(len(X_all)*0.8)
X_train = X_all[:split]
Y_train = Y_all[:split]
X_test = X_all[split:]
Y_test = Y_all[split:]
# ==============================================================
# MODEL
# ==============================================================
class TemporalGAT(torch.nn.Module):
def __init__(self):
super().__init__()
# GAT Layer 1
self.gat1 = GATv2Conv(
in_channels=4,
out_channels=16,
heads=4,
concat=True,
dropout=0.2
)
# GAT Layer 2
self.gat2 = GATv2Conv(
in_channels=64,
out_channels=16,
heads=2,
concat=True,
dropout=0.2
)
self.norm1 = LayerNorm(64)
self.norm2 = LayerNorm(32)
self.res = Linear(64,32)
self.drop = Dropout(0.2)
# Temporal layer
self.lstm = LSTM(
input_size=32,
hidden_size=64,
batch_first=True
)
# Prediction Head
self.head = Linear(64,2)
def graph_encode(self, x):
h1 = self.gat1(x, edge_index)
h1 = F.elu(h1)
h1 = self.norm1(h1)
h1 = self.drop(h1)
h2 = self.gat2(h1, edge_index)
h2 = h2 + self.res(h1)
h2 = F.elu(h2)
h2 = self.norm2(h2)
return h2
def forward(self, X):
print("Calling Input shape:", X)
# X = [B,S,N,F]
B,S,N,Fdim = X.shape
batch_graphs = []
for b in range(B):
seq_emb = []
for t in range(S):
emb = self.graph_encode(X[b,t]) # [N,32]
seq_emb.append(emb)
seq_emb = torch.stack(seq_emb)
batch_graphs.append(seq_emb)
batch_graphs = torch.stack(batch_graphs) # [B,S,N,32]
outputs = []
for node in range(N):
node_seq = batch_graphs[:,:,node,:] # [B,S,32]
out,_ = self.lstm(node_seq)
last = out[:,-1,:]
logits = self.head(last)
outputs.append(logits)
outputs = torch.stack(outputs, dim=1) # [B,N,2]
return outputs
# ==============================================================
# TRAINING
# ==============================================================
def train_model(model, X, Y, epochs=EPOCHS):
optimizer = torch.optim.Adam(
model.parameters(),
lr=0.003,
weight_decay=1e-4
)
best_loss = 1e9
best_state = None
for epoch in range(1, epochs+1):
model.train()
optimizer.zero_grad()
logits = model(X)
loss = F.cross_entropy(
logits.reshape(-1,2),
Y.reshape(-1)
)
loss.backward()
optimizer.step()
if loss.item() < best_loss:
best_loss = loss.item()
best_state = copy.deepcopy(model.state_dict())
if epoch % 3 == 0:
pred = logits.argmax(dim=2)
acc = (pred == Y).float().mean().item()
print(
f"Epoch {epoch:03d} | "
f"Loss {loss.item():.4f} | "
f"Acc {acc:.3f}"
)
model.load_state_dict(best_state)
# ==============================================================
# MONITORING + RETRAINING
# ==============================================================
def monitor_and_retrain(model):
print("\n================================================")
print("MONITORING")
print("================================================")
new_raw, new_y = generate_data(steps=TIMESTEPS, drift=True)
new_scaled = (new_raw - mean) / std
drift_score = torch.mean(
torch.abs(new_scaled - all_x)
).item()
print(f"Feature Drift Score: {drift_score:.4f}")
if drift_score > 0.05:
print("Drift Detected -> Retraining Triggered")
X_new, Y_new = make_windows(new_scaled, new_y)
train_model(model, X_new, Y_new, epochs=8)
return X_new, Y_new
print("No Retraining Needed")
return X_test, Y_test
# ==============================================================
# MAIN
# ==============================================================
model = TemporalGAT()
print("================================================")
print("INITIAL TRAINING")
print("================================================")
train_model(model, X_train, Y_train)
# ==============================================================
# RISK SCORES
# ==============================================================
print("\n================================================")
print("RISK SCORES")
print("================================================")
model.eval()
with torch.no_grad():
sample = X_test[:1]
logits = model(sample)
probs = F.softmax(logits, dim=2)[0,:,1]
for i,name in enumerate(services):
print(f"{name:15s} Risk Score = {probs[i]:.3f}")
# ==============================================================
# MONITORING
# ==============================================================
X_post, Y_post = monitor_and_retrain(model)
# ==============================================================
# POST RETRAIN PREDICTIONS
# ==============================================================
print("\n================================================")
print("POST-RETRAIN PREDICTIONS")
print("================================================")
model.eval()  # retraining inside monitor_and_retrain leaves the model in train mode
with torch.no_grad():
sample = X_post[:1]
logits = model(sample)
probs = F.softmax(logits, dim=2)[0,:,1]
for i,name in enumerate(services):
print(f"{name:15s} Risk Score = {probs[i]:.3f}")
Step-by-Step Flow of the Entire Temporal GAT Training Procedure
The easiest way to understand the training lifecycle is to think of it as a repeated cycle of:
Observe → Predict → Measure Error → Adjust Weights → Repeat
The model gradually learns how operational failures propagate through both:
graph structure
time evolution
1. Build the Knowledge Graph
The process begins by constructing the dependency graph.
edge_index = torch.tensor(...)
This defines:
Which nodes influence which other nodes
Example:
Frontend → Order
Order → Payment
At this stage, the graph only defines topology.
The model still knows nothing about operational behavior.
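For reference, a quick shape check on the tensor built above:
print(edge_index.shape)   # torch.Size([2, 8]) -> (source/target rows, 8 edges)
print(edge_index[:, 0])   # tensor([0, 1])     -> Frontend (0) -> Auth (1)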
2. Generate Temporal Telemetry
Synthetic telemetry is generated:
generate_data()
Each timestamp contains features like:
CPU
Memory
Latency
Error rate
Shape:
[Timestamps, Nodes, Features]
Example:
[300, 8, 4]
Meaning:
300 timestamps
8 services
4 operational metrics
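A quick sanity check of these shapes using the generator above:
all_x_raw, all_y = generate_data()
print(all_x_raw.shape)   # torch.Size([300, 8, 4]) -> timesteps x services x metrics
print(all_y.shape)       # torch.Size([300, 8])    -> one label per service per timestep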
3. Inject Failure Propagation Patterns
The simulation intentionally injects cascading failures.
Example:
Payment worsens
↓
Order worsens
↓
Frontend worsens
This teaches the model that:
risk propagates through dependencies
rather than appearing randomly.
4. Feature Scaling
Raw metrics are normalized:
all_x = (all_x_raw - mean) / std
This is critical because:
Latency may be hundreds
Error rate may be decimals
Without scaling:
the largest numeric feature (e.g. latency) would dominate learning
rather than the most informative one.
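The statistics are computed per feature, over all timesteps and nodes, so every metric ends up on a comparable scale:
mean = all_x_raw.mean((0, 1))        # one mean per feature, shape [4]
std  = all_x_raw.std((0, 1)) + 1e-8  # one std per feature, shape [4]
all_x = (all_x_raw - mean) / std
print(all_x.mean((0, 1)))            # roughly 0 for every feature
print(all_x.std((0, 1)))             # roughly 1 for every feature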
5. Build Temporal Windows
The continuous timeline is converted into sequences.
make_windows()
Example:
timestamps 1–8 → predict timestamp 9
timestamps 2–9 → predict timestamp 10
Shape becomes:
[Batch, Sequence, Nodes, Features]
Example:
[233, 8, 8, 4]
Meaning:
| Dimension | Meaning |
|---|---|
| 233 | training samples |
| 8 | time window |
| 8 | services |
| 4 | metrics |
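Running the helper from the script confirms these numbers:
X_all, Y_all = make_windows(all_x, all_y)
print(X_all.shape)   # torch.Size([292, 8, 8, 4]) -> 300 - SEQ_LEN windows
print(Y_all.shape)   # torch.Size([292, 8])
# 80/20 split -> 233 training windows, 59 test windows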
6. Initialize the Temporal GAT Model
The model architecture is created.
model = TemporalGAT()
At this point:
all neural weights are random
including:
graph attention weights
projection matrices
LSTM parameters
prediction head weights
The model has not learned anything yet.
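One way to confirm that these (still random) parameters exist is simply to count them:
model = TemporalGAT()
n_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {n_params}")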
7. Start Training Loop
Training begins:
for epoch in range(1, epochs+1):
Each epoch is one complete learning cycle.
8. Forward Pass Begins
This line triggers everything:
logits = model(X)
Internally:
model.__call__(X)
↓
forward(X)
Now the full Temporal GAT pipeline executes.
9. Forward Step — Process Temporal Batch
Inside forward():
B,S,N,Fdim = X.shape
Example:
[233, 8, 8, 4]
Meaning:
233 training sequences
8 timestamps each
8 services
4 features
10. Graph Encoding Per Timestamp
For every timestamp:
emb = self.graph_encode(X[b,t])
This processes:
ONE graph snapshot
through GAT.
11. GAT Layer 1 Executes
h1 = self.gat1(x, edge_index)
Now the model learns:
Which neighboring services matter most
using attention.
This stage performs:
linear projection
attention computation
softmax normalization
weighted aggregation
internally.
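If you want to inspect those attention coefficients, GATv2Conv can return them directly; with the layer's default self-loops the coefficient tensor has one row per edge plus one per node:
h1, (ei, alpha) = model.gat1(all_x[0], edge_index,
                             return_attention_weights=True)
print(alpha.shape)   # e.g. torch.Size([16, 4]) -> (8 edges + 8 self-loops) x 4 heads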
12. Activation + Stabilization
h1 = F.elu(h1)
h1 = self.norm1(h1)
h1 = self.drop(h1)
These steps:
| Layer | Purpose |
|---|---|
| ELU | non-linearity |
| LayerNorm | stabilize embeddings |
| Dropout | prevent overfitting |
13. GAT Layer 2 Executes
h2 = self.gat2(h1, edge_index)
Now deeper relationship propagation occurs.
Meaning:
multi-hop influence
can emerge.
Example:
Payment affects Order
Order affects Frontend
14. Residual Connection Applied
h2 = h2 + self.res(h1)
This helps prevent:
oversmoothing
where all node embeddings become too similar.
15. Graph Embeddings Produced
Final graph output:
[Nodes, 32]
Each node now has:
graph-aware embedding
representing:
local telemetry
dependency influence
propagated risk
16. Temporal Sequence Construction
For each node:
node_seq = batch_graphs[:,:,node,:]
This extracts:
node history across time
Shape:
[Batch, Sequence, Embedding]
17. LSTM Processes Temporal Evolution
out,_ = self.lstm(node_seq)
Now the responsibilities split cleanly:
the GAT layers learned spatial (dependency) relationships
the LSTM learns temporal behavior
18. Final Temporal State Selected
last = out[:,-1,:]
This captures:
summary of recent operational history
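Putting steps 16–18 together, the shapes inside forward() are:
node_seq = batch_graphs[:, :, node, :]   # [Batch, 8, 32]  one node's embedding history
out, _ = self.lstm(node_seq)             # [Batch, 8, 64]  LSTM output at every timestep
last = out[:, -1, :]                     # [Batch, 64]     final hidden state only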
19. Prediction Head Executes
logits = self.head(last)
The model now predicts:
Risky
vs
Not Risky
for each node.
20. Forward Pass Ends
Final output shape:
[Batch, Nodes, Classes]
Example:
[233, 8, 2]
21. Loss Calculation
Training compares predictions against truth:
loss = F.cross_entropy(...)
This computes:
How wrong the model currently is
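Concretely, the reshape flattens the (window, service) grid so every node prediction is scored independently:
loss = F.cross_entropy(
    logits.reshape(-1, 2),   # [233 * 8, 2] = [1864, 2]  one row per (window, service)
    Y.reshape(-1)            # [1864]  matching ground-truth labels (Y = Y_train here)
)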
22. Backpropagation Begins
loss.backward()
PyTorch now automatically computes gradients for:
GAT attention weights
projection matrices
LSTM weights
prediction head
This is the true learning stage.
23. Weight Update
optimizer.step()
Weights slightly adjust.
Meaning:
attention becomes more meaningful
predictions become less wrong
24. Entire Forward Pass Repeats
Training loops again:
forward
→ loss
→ gradients
→ update
→ repeat
This repeated correction is how the network gradually learns operational behavior.
25. Best Model Saved
best_state = copy.deepcopy(...)
The best-performing weights are preserved.
This matters because graph training can become unstable again after temporarily improving.
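In a production setting the preserved weights would normally also be persisted to disk. A minimal sketch (the filename is illustrative, not part of the script):
# after train_model() the best weights are already loaded into the model
torch.save(model.state_dict(), "temporal_gat_best.pt")
# ...and restored later, e.g. at service startup
model.load_state_dict(torch.load("temporal_gat_best.pt"))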
26. Inference / Risk Prediction
After training:
model.eval()
Now the model performs prediction only.
No weights change.
27. Softmax Converts Logits to Risk Probabilities
probs = F.softmax(logits, dim=2)
Now predictions become:
0.93 risk probability
0.12 risk probability
instead of raw logits.
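A tiny worked example for a single node:
z = torch.tensor([[-1.0, 1.5]])   # [not-risky, risky] logits
print(F.softmax(z, dim=1))        # tensor([[0.0759, 0.9241]]) -> ~0.92 risk score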
28. Monitoring Phase Begins
New telemetry arrives.
The system checks:
drift_score = ...
to determine whether the environment changed significantly.
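The script uses a single mean absolute difference over all scaled features. A minimal per-feature variant one could add inside monitor_and_retrain (an illustrative sketch, not part of the script above):
feature_shift = (new_scaled.mean((0, 1)) - all_x.mean((0, 1))).abs()
print(feature_shift)   # one drift value per metric: CPU, memory, latency, error rate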
29. Drift Detection
If operational behavior changes:
new traffic patterns
new failures
new dependencies
the model may become stale.
30. Retraining Triggered
If drift exceeds threshold:
train_model(...)
runs again using newer data.
This is critical because:
graph intelligence decays over time
if operational systems evolve.
Final Conceptual Flow
Telemetry
↓
Graph Attention
↓
Dependency-Aware Embeddings
↓
Temporal Learning
↓
Risk Prediction
↓
Loss Calculation
↓
Backpropagation
↓
Weight Updates
↓
Monitoring
↓
Retraining
Most Important Insight
The model is not simply learning:
“Which node is unhealthy?”
It is gradually learning:
“How operational degradation propagates
through connected systems over time.”
That is the real significance of Temporal Graph Learning.