PyTorch Geometric for Beginners: Create, Visualize, and Train on Graph Data
If you’re just starting out with graph neural networks using PyTorch Geometric (PyG), it can feel overwhelming. This post shares my practical notes on how to create a graph, visualize it, set up the dataset, and build a simple model using PyG. I’ve kept it minimal and functional—just enough to get you up and running.
1. What model.train() and model.eval() Really Mean
model.train()  # Puts the model in training mode (dropout is active, batch norm uses batch statistics)
model.eval()   # Puts the model in evaluation mode (dropout is off, batch norm uses its running statistics)
These do not train or evaluate the model themselves. They only change the internal behavior of layers.
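To make the effect concrete, here is a minimal sketch using a standalone Dropout layer. It is not part of the model in this post, and the seed and tensor size are arbitrary choices for illustration. In training mode the layer zeroes some entries and rescales the rest; in evaluation mode it passes the input through unchanged.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for repeatability

drop = torch.nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()        # training mode: dropout is active
y_train = drop(x)   # each entry is either 0 or scaled by 1 / (1 - p) = 2

drop.eval()         # evaluation mode: dropout is a no-op
y_eval = drop(x)

print("values seen in train mode:", y_train.unique().tolist())
print("eval output equals input:", torch.equal(y_eval, x))
```

Neither call ran an optimizer step or computed a metric; only the layer's behavior changed.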
2. Create a Single Graph
import torch
from torch_geometric.data import Data
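# Define edges in COO format: row 0 holds source nodes, row 1 holds target nodes.
# Each undirected edge is listed twice, once per direction.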
edges = [[0, 1, 1, 2, 0, 2], [1, 0, 2, 1, 2, 0]]
edge_index = torch.tensor(edges, dtype=torch.long)
# Node features (3 nodes, 2 features each)
x = torch.tensor([[2, 3], [-1, 2], [-4, 1]], dtype=torch.float)
# Create the data object
data = Data(x=x, edge_index=edge_index)
print(data)
# Output: Data(x=[3, 2], edge_index=[2, 6])
You can freely add new attributes to the Data object as needed.
3. Visualize the Graph
import networkx as nx
from torch_geometric.utils import to_networkx
G = to_networkx(data)
nx.draw(G, with_labels=True)
Visualizing the structure often helps understand how nodes are connected.
4. Generate Adjacency Matrix
from torch_geometric.transforms import ToDense
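import copy
# ToDense replaces edge_index with a dense adj matrix, so work on a copy.
# Transforms return the transformed object; keep the return value rather than relying on in-place changes.
dataTest = ToDense(num_nodes=3)(copy.deepcopy(data))
print(dataTest.adj)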
Useful for checking consistency between the adjacency matrix and the feature matrix.
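The same check can be sketched in plain PyTorch, without the transform. This is an illustrative alternative rather than what ToDense does internally; it builds the dense matrix directly from `edge_index` and compares it against the feature matrix.

```python
import torch

edge_index = torch.tensor([[0, 1, 1, 2, 0, 2],
                           [1, 0, 2, 1, 2, 0]], dtype=torch.long)
x = torch.tensor([[2, 3], [-1, 2], [-4, 1]], dtype=torch.float)

num_nodes = x.size(0)  # one row of features per node

# adj[i, j] = 1 when there is an edge from node i to node j
adj = torch.zeros(num_nodes, num_nodes)
adj[edge_index[0], edge_index[1]] = 1.0

# Consistency checks: every node index has a feature row, and
# an undirected graph should give a symmetric matrix
indices_in_range = int(edge_index.max()) < num_nodes
is_symmetric = torch.equal(adj, adj.t())

print(adj)
print("indices in range:", indices_in_range)
print("symmetric:", is_symmetric)
```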
5. Random Train/Val/Test Split
from torch_geometric.transforms import RandomNodeSplit
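# The transform is keyed on node labels (key="y" by default), so set data.y before splitting.
# With num_val=0.3 and num_test=0.6, only about 10% of the nodes are left for training.
split = RandomNodeSplit(split="train_rest", num_val=0.3, num_test=0.6)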
data = split(data)
print(torch.sum(data.train_mask).item())
print(torch.sum(data.val_mask).item())
print(torch.sum(data.test_mask).item())
This adds train_mask, val_mask, and test_mask to your Data object.
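To see how the fractions turn into node counts, here is a plain-PyTorch sketch of a "validation and test first, train gets the rest" split. `random_node_masks` is a hypothetical helper written for this note, not PyG's implementation, and it assumes fractions are turned into counts by rounding.

```python
import torch

def random_node_masks(num_nodes, val_frac, test_frac, seed=0):
    """Hypothetical helper: random, non-overlapping train/val/test masks."""
    num_val = round(num_nodes * val_frac)
    num_test = round(num_nodes * test_frac)

    gen = torch.Generator().manual_seed(seed)
    perm = torch.randperm(num_nodes, generator=gen)

    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)

    val_mask[perm[:num_val]] = True
    test_mask[perm[num_val:num_val + num_test]] = True
    train_mask[perm[num_val + num_test:]] = True  # whatever is left over
    return train_mask, val_mask, test_mask

# Columns: nodes, train count, val count, test count
for n in (3, 10, 100):
    train_mask, val_mask, test_mask = random_node_masks(n, 0.3, 0.6)
    print(n, int(train_mask.sum()), int(val_mask.sum()), int(test_mask.sum()))
```

Under these assumptions the 3-node toy graph ends up with no training nodes at all, so a larger graph or smaller validation and test fractions are needed before the loss is meaningful.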
6. Build a Simple GCN Model
from torch.nn import Linear
from torch_geometric.nn import GCNConv
class GCN(torch.nn.Module):
    def __init__(self, in_features, out_classes):
        super().__init__()
        self.conv1 = GCNConv(in_features, 4)
        self.conv2 = GCNConv(4, 4)
        self.conv3 = GCNConv(4, 2)
        self.classifier = Linear(2, out_classes)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index).tanh()
        h = self.conv2(h, edge_index).tanh()
        h = self.conv3(h, edge_index).tanh()
        # Return class logits and the 2-D node embedding
        return self.classifier(h), h
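For intuition about what a single graph convolution computes, here is a dense, plain-PyTorch sketch of the standard GCN propagation rule (add self-loops, normalize symmetrically by degree, then apply a linear map). It is a teaching sketch only: PyG's GCNConv works on the sparse `edge_index`, and the `Linear` layer here merely stands in for its weights.

```python
import torch

torch.manual_seed(0)  # arbitrary seed for the stand-in weights

edge_index = torch.tensor([[0, 1, 1, 2, 0, 2],
                           [1, 0, 2, 1, 2, 0]], dtype=torch.long)
x = torch.tensor([[2, 3], [-1, 2], [-4, 1]], dtype=torch.float)
num_nodes = x.size(0)

# Dense adjacency with self-loops: A_hat = A + I
adj = torch.zeros(num_nodes, num_nodes)
adj[edge_index[0], edge_index[1]] = 1.0
a_hat = adj + torch.eye(num_nodes)

# Symmetric normalization: D^{-1/2} A_hat D^{-1/2}
deg = a_hat.sum(dim=1)
d_inv_sqrt = deg.pow(-0.5)
a_norm = d_inv_sqrt.unsqueeze(1) * a_hat * d_inv_sqrt.unsqueeze(0)

# One layer: average over neighbors (and self), then a learnable linear map
linear = torch.nn.Linear(2, 4)  # stand-in for the weights of a 2 -> 4 layer
h = torch.tanh(linear(a_norm @ x))

print(a_norm)
print(h.shape)
```

On this fully connected triangle every entry of the normalized matrix is 1/3, so all three nodes receive the same aggregated features.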
7. Training Function
import torch.nn.functional as F
def train(model, data, optimizer):
    model.train()
    optimizer.zero_grad()
    out, _ = model(data.x, data.edge_index)
    # Compute the loss on training nodes only
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()
F.cross_entropy (the functional form of CrossEntropyLoss) expects raw logits, so the model should not apply softmax to its output.
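The reason is that the loss applies log-softmax internally. A quick check with made-up logits and labels (random values, not from any dataset) shows that cross-entropy on raw logits matches negative log-likelihood on log-softmax outputs:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # arbitrary seed for the made-up logits

logits = torch.randn(5, 4)               # 5 nodes, 4 classes, raw scores
labels = torch.tensor([0, 3, 1, 2, 1])   # made-up integer class labels

loss_a = F.cross_entropy(logits, labels)
loss_b = F.nll_loss(F.log_softmax(logits, dim=1), labels)

print(loss_a.item(), loss_b.item())
```

Applying softmax in the model and then passing the result to cross-entropy would therefore normalize twice.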
8. Run Training Loop
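# The keyword must match the constructor argument (out_classes).
# data.y must hold integer class labels in the range 0..3 for 4 classes.
model = GCN(data.num_features, out_classes=4)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(200):
    loss = train(model, data, optimizer)
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss:.4f}")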
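Once training has run, a matching evaluation step might look like the sketch below. `evaluate` is a hypothetical helper, not from the original notes; it switches to eval mode, disables gradient tracking, and reports accuracy on a mask. `TinyModel` and the `toy` object are stand-ins so the snippet runs without PyG; with the real setup you would pass the GCN model and the Data object instead.

```python
import torch
from types import SimpleNamespace

@torch.no_grad()
def evaluate(model, data, mask):
    """Hypothetical helper: accuracy on the nodes selected by mask."""
    model.eval()
    out, _ = model(data.x, data.edge_index)
    pred = out[mask].argmax(dim=1)
    return (pred == data.y[mask]).float().mean().item()

class TinyModel(torch.nn.Module):
    """Stand-in that returns (logits, embedding) like the GCN in this post."""
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(2, 2, bias=False)
        with torch.no_grad():
            self.lin.weight.copy_(torch.eye(2))  # identity, so logits equal x

    def forward(self, x, edge_index):
        h = self.lin(x)  # edge_index is ignored in this stand-in
        return h, h

toy = SimpleNamespace(
    x=torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]),
    edge_index=torch.tensor([[0, 1], [1, 0]]),
    y=torch.tensor([0, 1, 1]),
    test_mask=torch.tensor([True, True, True]),
)

acc = evaluate(TinyModel(), toy, toy.test_mask)
print(f"accuracy: {acc:.2f}")
```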
9. Count Trainable Parameters
print("Trainable params:", sum(p.numel() for p in model.parameters() if p.requires_grad))
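As a sanity check on that number, the count can be worked out by hand. The sketch assumes each GCNConv layer holds one weight matrix plus one bias vector (its default configuration, as far as I know), just like a Linear layer, and uses in_features=2 and out_classes=4 as in this post.

```python
def affine_params(n_in, n_out):
    # weight matrix (n_in * n_out) plus bias vector (n_out)
    return n_in * n_out + n_out

in_features, out_classes = 2, 4

# (inputs, outputs) for each layer of the GCN defined in this post
layers = {
    "conv1": (in_features, 4),
    "conv2": (4, 4),
    "conv3": (4, 2),
    "classifier": (2, out_classes),
}

counts = {name: affine_params(n_in, n_out) for name, (n_in, n_out) in layers.items()}
total = sum(counts.values())

print(counts)
print("total:", total)
```

Under that assumption the layers contribute 12, 20, 10, and 12 parameters, for a total of 54.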
Final Thoughts
This note doesn’t cover everything but touches the key steps: how to create a graph, visualize it, prepare data, and build a simple GCN model. If you’re just starting out with PyG, I hope this helps get you past the initial barrier.
👋 About Me
Hi, I’m Shuvangkar Das, a power systems researcher with a Ph.D. in Electrical Engineering from Clarkson University. I work at the intersection of power electronics, DER, IBR, and AI — building greener, smarter, and more stable grids. Currently, I’m a Research Engineer at EPRI (though everything I share here reflects my personal experience, not my employer’s views).
Over the years, I’ve worked on real-world projects involving large-scale EMT simulation, firmware development for grid-forming and grid-following inverters, and reinforcement learning (RL). I also publish technical content and share hands-on insights with the goal of making complex ideas accessible to engineers and researchers.