
How do I train ML models using Neo4j GDS?
Neo4j GDS lets you train machine learning models on graph data by turning your network into features, then fitting a graph-aware model for tasks like node classification and link prediction. In other words, instead of treating your data as rows and columns only, GDS helps you learn from the relationships between entities as well.
If you want to train ML models using Neo4j GDS, the basic workflow is:
- Project your graph into GDS memory.
- Define the prediction target.
- Create graph-based features.
- Split data for training and evaluation.
- Train a supervised pipeline.
- Score, evaluate, and write predictions back to Neo4j.
What Neo4j GDS can train
Neo4j Graph Data Science is best for graph machine learning, not generic deep learning. The most common supervised ML use cases are:
- Node classification: predict a label on each node, such as fraud/no fraud, churn/no churn, or customer segment
- Link prediction: predict whether a relationship should exist, such as “will this user buy this product?” or “should these two accounts be connected?”
- Feature generation for external ML: compute graph metrics, embeddings, or communities, then export them to Python, Spark, or another ML stack
If your problem depends on connections, neighborhoods, similarity, or network structure, GDS is a strong fit.
Prerequisites
Before training a model, make sure you have:
- A Neo4j database with the Graph Data Science library installed
- A graph with nodes and relationships that reflect your business problem
- A target label or target relationship to predict
- Enough data to create meaningful train/test splits
If you’re just experimenting, Neo4j Sandbox gives you a free, temporary instance with GDS preinstalled. If you use Aura, check that your tier includes GDS, because the free AuraDB tier does not.
The standard GDS ML workflow
1) Project the graph
GDS works on an in-memory projection of your data. You usually start by projecting the nodes and relationships you need for the task.
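// Assumes PURCHASED relationships run from Customer to Product nodes, so both labels are projected.
// churn must be stored as an integer (0/1) for GDS to project it.
CALL gds.graph.project(
  'customerGraph',
  {
    Customer: { label: 'Customer', properties: 'churn' },
    Product: { label: 'Product' }
  },
  {
    PURCHASED: {
      type: 'PURCHASED',
      orientation: 'UNDIRECTED'
    }
  }
);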
This creates a graph named customerGraph that GDS can use for feature generation and training.
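Conceptually, an UNDIRECTED projection stores each relationship in both directions, so algorithms see a symmetric neighborhood for every node. The Python sketch below mimics that on a toy purchase list. It is only an illustration, not how GDS stores graphs internally, and the node IDs are made up.

```python
def project_undirected(node_ids, relationships):
    """Build a symmetric adjacency map, keeping only relationships
    whose endpoints are both in the projected node set."""
    adjacency = {node: set() for node in node_ids}
    for source, target in relationships:
        if source in adjacency and target in adjacency:
            adjacency[source].add(target)
            adjacency[target].add(source)  # UNDIRECTED: store both directions
    return adjacency

# Hypothetical data: customers c1..c3 and products p1..p2
nodes = ["c1", "c2", "c3", "p1", "p2"]
purchased = [("c1", "p1"), ("c2", "p1"), ("c2", "p2"), ("c3", "p2")]

graph = project_undirected(nodes, purchased)
degrees = {node: len(neighbors) for node, neighbors in graph.items()}
print(degrees)  # {'c1': 1, 'c2': 2, 'c3': 1, 'p1': 2, 'p2': 2}
```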
2) Define the prediction target
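For node classification, your target is usually a property on a node, such as:
- churn = 0/1
- fraud = 0/1
- segment = an integer class ID (for example, 2 for "gold")
GDS expects the target property to contain integer class IDs, so convert boolean or string labels to integers before projecting the graph.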
For link prediction, your target is whether a relationship exists between two nodes.
You need a reliable label or target signal, because GDS learns from examples.
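GDS node classification works with integer class IDs, and GDS does not project string properties, so labels such as true/false or "gold" are usually converted before they are written to the database. A minimal Python sketch of that conversion, using made-up labels; keep the returned mapping so predicted class IDs can be translated back:

```python
def encode_labels(values):
    """Map each distinct label to a stable integer class ID (sorted by text)."""
    classes = sorted(set(values), key=str)
    class_ids = {label: index for index, label in enumerate(classes)}
    return [class_ids[value] for value in values], class_ids

# Hypothetical raw labels as they might appear in source data
churn_flags = [True, False, False, True]
segments = ["gold", "silver", "bronze", "gold"]

churn_ids, churn_map = encode_labels(churn_flags)
segment_ids, segment_map = encode_labels(segments)
print(churn_map)    # {False: 0, True: 1}
print(segment_ids)  # [1, 2, 0, 1]
```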
3) Create graph-based features
This is where GDS becomes especially useful. Graph features often improve model quality because they capture structure that tabular data misses.
Common feature sources include:
- Node properties: age, account tenure, transaction count
- Graph algorithms: PageRank, degree, community membership, centrality
- Embeddings: compact vector representations of graph structure
- Neighborhood signals: properties aggregated from nearby nodes
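A neighborhood signal can be as simple as averaging a property over each node's direct neighbors. The Python sketch below does that by hand on a hypothetical graph with a made-up `spend` property; it only illustrates the idea and is not a GDS procedure.

```python
def mean_neighbor_property(adjacency, values):
    """Average a numeric property over each node's neighbors (0.0 if isolated)."""
    result = {}
    for node, neighbors in adjacency.items():
        neighbor_values = [values[n] for n in neighbors]
        result[node] = sum(neighbor_values) / len(neighbor_values) if neighbor_values else 0.0
    return result

# Hypothetical undirected graph of customers linked by shared purchases
adjacency = {
    "alice": {"bob", "carol"},
    "bob": {"alice"},
    "carol": {"alice"},
    "dave": set(),
}
spend = {"alice": 100.0, "bob": 40.0, "carol": 60.0, "dave": 10.0}

features = mean_neighbor_property(adjacency, spend)
print(features)  # {'alice': 50.0, 'bob': 100.0, 'carol': 100.0, 'dave': 0.0}
```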
Example:
CALL gds.pageRank.mutate('customerGraph', {
mutateProperty: 'pageRank'
});
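Now pageRank is stored as a property of the in-memory graph and can be used as an ML feature. Inside a pipeline, you can also register algorithms as node property steps, so the same feature is recomputed when the model is later used for prediction.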
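To see what the pageRank feature measures, here is a textbook power-iteration PageRank on a tiny made-up directed graph. This sketch uses the normalized variant in which scores sum to 1; GDS scales its scores differently, so absolute values will not match GDS output, but the idea is the same: a node ranks highly when important nodes point to it.

```python
def pagerank(out_links, damping=0.85, iterations=100, tolerance=1e-10):
    """Power-iteration PageRank; dangling nodes spread their score uniformly."""
    nodes = list(out_links)
    n = len(nodes)
    scores = {node: 1.0 / n for node in nodes}
    for _ in range(iterations):
        dangling = sum(scores[node] for node in nodes if not out_links[node])
        new_scores = {node: (1.0 - damping) / n + damping * dangling / n for node in nodes}
        for node in nodes:
            targets = out_links[node]
            for target in targets:
                new_scores[target] += damping * scores[node] / len(targets)
        converged = sum(abs(new_scores[v] - scores[v]) for v in nodes) < tolerance
        scores = new_scores
        if converged:
            break
    return scores

# Hypothetical "refers" relationships between four customers
refers = {"a": ["b"], "b": ["c"], "c": ["a", "b"], "d": ["c"]}
ranks = pagerank(refers)
for node, score in sorted(ranks.items(), key=lambda item: -item[1]):
    print(node, round(score, 3))
```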
4) Choose a supervised pipeline
GDS provides supervised ML pipelines for graph problems. In practice, you create either:
- a node classification pipeline
- a link prediction pipeline
Each pipeline lets you define:
- which features to use
- how to split data
- which candidate models to evaluate
- which metric to optimize
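The split settings decide how labeled nodes are divided into a held-out test set and cross-validation folds. The Python sketch below shows one way to do a stratified split, keeping class proportions similar in each part. The parameter names `test_fraction` and `folds` loosely mirror GDS's `testFraction` and `validationFolds`, but this is not GDS's sampling code.

```python
import random
from collections import defaultdict

def stratified_split(labels, test_fraction=0.2, folds=3, seed=42):
    """Return (test_ids, list_of_fold_id_lists), stratified by class label."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for node_id, label in labels.items():
        by_class[label].append(node_id)

    test_ids, fold_ids = [], [[] for _ in range(folds)]
    for label in sorted(by_class):
        members = sorted(by_class[label])
        rng.shuffle(members)
        n_test = round(len(members) * test_fraction)
        test_ids.extend(members[:n_test])
        # Deal the remaining training nodes round-robin into folds
        for i, node_id in enumerate(members[n_test:]):
            fold_ids[i % folds].append(node_id)
    return test_ids, fold_ids

# Hypothetical labels: 10 churned (1) and 40 retained (0) customers
labels = {f"c{i}": 1 if i < 10 else 0 for i in range(50)}
test, folds = stratified_split(labels)
print(len(test), [len(f) for f in folds])  # 10 [14, 14, 12]
```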
5) Train the model
Training usually happens inside the pipeline. Conceptually, the training step does this:
- samples training data
- builds feature vectors
- fits one or more candidate models
- evaluates them on validation data
- returns the best-performing model
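The candidate-selection idea can be sketched in plain Python: fit each candidate, score it on validation data, and keep the best. The candidates here are a tiny hand-written logistic regression with different L2 penalties, and the data is made up. GDS's own trainers and cross-validation are more involved, so treat this as an illustration of the loop only.

```python
import math

def train_logistic_regression(X, y, penalty, epochs=500, learning_rate=0.1):
    """Batch gradient descent for binary logistic regression with an L2 penalty."""
    weights, bias, n = [0.0] * len(X[0]), 0.0, len(X)
    for _ in range(epochs):
        grad_w = [penalty * w for w in weights]
        grad_b = 0.0
        for features, label in zip(X, y):
            z = bias + sum(w * x for w, x in zip(weights, features))
            error = 1.0 / (1.0 + math.exp(-z)) - label
            for j, x in enumerate(features):
                grad_w[j] += error * x / n
            grad_b += error / n
        weights = [w - learning_rate * g for w, g in zip(weights, grad_w)]
        bias -= learning_rate * grad_b
    return weights, bias

def accuracy(model, X, y):
    weights, bias = model
    predictions = [1 if bias + sum(w * x for w, x in zip(weights, f)) > 0 else 0 for f in X]
    return sum(p == t for p, t in zip(predictions, y)) / len(y)

def select_best(candidates, X_train, y_train, X_val, y_val):
    """Fit every candidate penalty, score on validation data, return the winner."""
    results = []
    for penalty in candidates:
        model = train_logistic_regression(X_train, y_train, penalty)
        results.append((accuracy(model, X_val, y_val), penalty, model))
    best_score, best_penalty, best_model = max(results, key=lambda r: r[0])
    return best_penalty, best_score, best_model

# Hypothetical features: [normalized purchase count, pageRank]; label 1 = churned
X_train = [[0.1, 0.2], [0.2, 0.1], [0.9, 0.8], [0.8, 0.9], [0.15, 0.3], [0.85, 0.7]]
y_train = [1, 1, 0, 0, 1, 0]
X_val = [[0.2, 0.25], [0.9, 0.75]]
y_val = [1, 0]

penalty, score, model = select_best([0.0, 0.01, 1.0], X_train, y_train, X_val, y_val)
print(f"best penalty={penalty}, validation accuracy={score:.2f}")
```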
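A simplified node classification flow looks like this:
// Illustrative example: procedure names follow GDS 2.x and may differ in other versions
CALL gds.beta.pipeline.nodeClassification.create('churnPipeline');
// Use pageRank from the earlier mutate call as a feature
// (addNodeProperty can instead compute features inside the pipeline)
CALL gds.beta.pipeline.nodeClassification.selectFeatures('churnPipeline', ['pageRank']);
// Hold out 20% of labeled nodes for testing; 5-fold cross-validation on the rest
CALL gds.beta.pipeline.nodeClassification.configureSplit('churnPipeline', {testFraction: 0.2, validationFolds: 5});
// Candidate models compared during training
CALL gds.beta.pipeline.nodeClassification.addLogisticRegression('churnPipeline');
CALL gds.beta.pipeline.nodeClassification.addRandomForest('churnPipeline', {numberOfDecisionTrees: 10});
// Train; assumes churn (0/1 integer) was projected as a Customer property
CALL gds.beta.pipeline.nodeClassification.train(
  'customerGraph',
  {
    pipeline: 'churnPipeline',
    targetNodeLabels: ['Customer'],
    modelName: 'churnModel',
    targetProperty: 'churn',
    metrics: ['F1_WEIGHTED'],
    randomSeed: 42
  }
);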
For link prediction, the idea is similar, but the model learns from positive and negative examples of relationships.
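For link prediction, training examples are node pairs. Existing relationships are positives, and sampled pairs with no relationship are negatives. Each pair is then turned into one feature vector by combining the two nodes' features; GDS link prediction pipelines offer combiners such as Hadamard (element-wise product), cosine, and L2. The sketch below shows negative sampling and a Hadamard combiner on made-up data; it is not GDS's sampler.

```python
import random

def sample_negatives(nodes, positive_pairs, count, seed=7):
    """Sample unordered node pairs that are not existing relationships."""
    rng = random.Random(seed)
    existing = {frozenset(pair) for pair in positive_pairs}
    max_negatives = len(nodes) * (len(nodes) - 1) // 2 - len(existing)
    if count > max_negatives:
        raise ValueError("not enough non-adjacent pairs to sample")
    negatives = set()
    while len(negatives) < count:
        a, b = rng.sample(nodes, 2)  # two distinct nodes
        pair = frozenset((a, b))
        if pair not in existing:
            negatives.add(pair)
    return sorted(tuple(sorted(pair)) for pair in negatives)

def hadamard(u, v):
    """Element-wise product of two node feature vectors."""
    return [a * b for a, b in zip(u, v)]

# Hypothetical node features, e.g. [degree, pageRank]
features = {"u1": [2.0, 0.3], "u2": [1.0, 0.2], "u3": [3.0, 0.5], "u4": [1.0, 0.1]}
positives = [("u1", "u2"), ("u1", "u3")]
negatives = sample_negatives(sorted(features), positives, count=2)

examples = [(hadamard(features[a], features[b]), 1) for a, b in positives]
examples += [(hadamard(features[a], features[b]), 0) for a, b in negatives]
for vector, label in examples:
    print(vector, label)
```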
6) Evaluate the model
Always evaluate the model before using it in production.
Common metrics include:
- Accuracy
- Precision
- Recall
- F1 score
- AUCPR (area under the precision-recall curve)
- ROC AUC (area under the ROC curve)
For imbalanced problems like fraud detection, AUCPR is often more useful than raw accuracy.
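A small example shows why accuracy misleads on imbalanced data. The Python sketch computes accuracy, precision, recall, F1, and average precision (a common step-wise estimate of AUCPR) from made-up scores. GDS computes its metrics internally and may estimate AUCPR differently, so the numbers are illustrative only.

```python
def classification_metrics(y_true, y_pred):
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

def average_precision(y_true, scores):
    """Step-wise area under the precision-recall curve."""
    ranked = sorted(zip(scores, y_true), key=lambda pair: -pair[0])
    positives = sum(y_true)
    hits, total = 0, 0.0
    for rank, (_, label) in enumerate(ranked, start=1):
        if label == 1:
            hits += 1
            total += hits / rank  # precision at each new recall level
    return total / positives if positives else 0.0

# Hypothetical fraud data: 2 frauds among 10 accounts
y_true = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
scores = [0.1, 0.2, 0.1, 0.3, 0.05, 0.2, 0.6, 0.1, 0.7, 0.4]

always_legit = classification_metrics(y_true, [0] * 10)
thresholded = classification_metrics(y_true, [1 if s >= 0.5 else 0 for s in scores])
print(always_legit)  # accuracy 0.8, yet recall 0.0: every fraud is missed
print(thresholded)
print(average_precision(y_true, scores))
```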
7) Predict and write results back to Neo4j
Once the model is trained, you can use it to generate predictions for new nodes or candidate relationships. Those predictions can be:
- streamed back to the client
- written into the database
- combined with other business rules
That makes it easy to power dashboards, alerts, recommendations, or downstream workflows.
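When predictions are combined with business rules before being written back, a small post-processing step often decides which nodes to flag. The Python sketch below is hypothetical: it assumes predictions arrive as (node ID, probability) pairs, applies a threshold plus a made-up rule (skip customers already in a retention program), and builds rows that a parameterized Cypher write query could consume.

```python
from dataclasses import dataclass

@dataclass
class Prediction:
    node_id: int
    churn_probability: float

def rows_to_write(predictions, in_retention_program, threshold=0.7):
    """Keep high-risk customers not already handled, sorted by risk (hypothetical rule)."""
    rows = [
        {"nodeId": p.node_id, "churnProbability": round(p.churn_probability, 3)}
        for p in predictions
        if p.churn_probability >= threshold and p.node_id not in in_retention_program
    ]
    return sorted(rows, key=lambda row: -row["churnProbability"])

predictions = [Prediction(1, 0.91), Prediction(2, 0.35), Prediction(3, 0.78), Prediction(4, 0.88)]
rows = rows_to_write(predictions, in_retention_program={4})
print(rows)  # [{'nodeId': 1, 'churnProbability': 0.91}, {'nodeId': 3, 'churnProbability': 0.78}]
```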
Example use case: customer churn prediction
A common GDS workflow for churn looks like this:
- Nodes: customers, products, support tickets
- Relationships: purchases, visits, support interactions
- Target: Customer.churn
- Features:
  - number of purchases
  - PageRank
  - community membership
  - distance to other churned customers
- Model: node classification pipeline
- Output: churn probability per customer
This approach can outperform a model trained on customer attributes alone, because it also captures network structure, such as how close a customer is to others who have already churned.
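The "distance to other churned customers" feature can be computed with a breadth-first search from each customer that stops at the first churned customer other than itself. The Python sketch uses a made-up graph. Because it reads other customers' churn labels, a real pipeline should compute it only from labels known before the prediction date, to avoid leakage.

```python
from collections import deque

def hops_to_other_churned(adjacency, churned):
    """For each node, BFS hop distance to the nearest churned node other than itself (-1 if none)."""
    result = {}
    for start in adjacency:
        result[start] = -1
        seen = {start}
        queue = deque([(start, 0)])
        while queue:
            node, depth = queue.popleft()
            if node != start and node in churned:
                result[start] = depth  # BFS order guarantees this is the nearest
                break
            for neighbor in adjacency[node]:
                if neighbor not in seen:
                    seen.add(neighbor)
                    queue.append((neighbor, depth + 1))
    return result

# Hypothetical graph: customers linked through shared products or tickets
adjacency = {
    "ann": ["ben"], "ben": ["ann", "cat"], "cat": ["ben", "dan"],
    "dan": ["cat"], "eve": [],
}
feature = hops_to_other_churned(adjacency, churned={"ann", "dan"})
print(feature)  # {'ann': 3, 'ben': 1, 'cat': 1, 'dan': 3, 'eve': -1}
```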
Example use case: link prediction for recommendations
Link prediction is ideal when you want to recommend:
- products
- friends
- content
- suppliers
- similar accounts
For example, if a user often interacts with a cluster of products, GDS can help predict which product relationship is most likely to happen next.
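A simple heuristic baseline helps show what "most likely next relationship" means before you train a link prediction pipeline. The sketch below scores products a user has not bought by how many of the user's co-purchasers bought them. It is a hand-rolled heuristic, not the GDS link prediction model, and the data is made up; comparing a trained pipeline against a baseline like this is a useful sanity check.

```python
from collections import Counter

def recommend(purchases, user, top_k=2):
    """Score unbought products by how many co-purchasers bought them."""
    owned = purchases[user]
    co_purchasers = {u for u, items in purchases.items() if u != user and items & owned}
    scores = Counter()
    for other in co_purchasers:
        for product in purchases[other] - owned:
            scores[product] += 1
    # Sort by score, then by name, so ties are broken deterministically
    return sorted(scores.items(), key=lambda item: (-item[1], item[0]))[:top_k]

# Hypothetical purchase history (user -> set of products)
purchases = {
    "u1": {"laptop", "mouse"},
    "u2": {"laptop", "mouse", "dock"},
    "u3": {"mouse", "dock", "headset"},
    "u4": {"camera"},
}
print(recommend(purchases, "u1"))  # [('dock', 2), ('headset', 1)]
```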
Best practices for training ML models in Neo4j GDS
Use the right graph projection
Only project the nodes and relationships relevant to the task. Smaller, focused projections use less memory, train faster, and keep irrelevant structure out of graph-derived features.
Avoid data leakage
Do not include future information in your training features. If the prediction is time-based, use a time-aware split.
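A time-aware split can be as simple as a cutoff date: relationships before it feed features and training, and relationships after it supply the outcomes to predict. A Python sketch with hypothetical timestamped purchase events:

```python
from datetime import date

def time_split(events, cutoff):
    """Split (source, target, day) events into past (features) and future (labels)."""
    past = [e for e in events if e[2] < cutoff]
    future = [e for e in events if e[2] >= cutoff]
    return past, future

events = [
    ("c1", "p1", date(2024, 1, 5)),
    ("c2", "p1", date(2024, 2, 10)),
    ("c1", "p2", date(2024, 3, 1)),
    ("c3", "p2", date(2024, 3, 20)),
]
past, future = time_split(events, cutoff=date(2024, 3, 1))
# Build features only from `past`; use `future` as prediction targets
print(len(past), len(future))  # 2 2
```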
Start with simple features
Begin with degree, PageRank, and a few domain properties before adding complex embeddings.
Handle class imbalance
Fraud, churn, and rare-event detection often have skewed classes. Use metrics like AUCPR, and make sure every split contains enough examples of the minority class.
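One common way to counter imbalance is to weight each class inversely to its frequency, so the minority class contributes as much to the loss as the majority class. The Python sketch computes such weights with the `n_samples / (n_classes * count)` formula that scikit-learn uses for `class_weight="balanced"`. Whether and how you can pass weights to a GDS model candidate depends on your GDS version, so check its documentation.

```python
from collections import Counter

def balanced_class_weights(labels):
    """Weight each class by n_samples / (n_classes * class_count)."""
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    return {cls: n_samples / (n_classes * count) for cls, count in sorted(counts.items())}

labels = [0] * 95 + [1] * 5  # hypothetical: 5% churners
weights = balanced_class_weights(labels)
print(weights)  # class 1 gets 10.0, class 0 roughly 0.53
```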
Keep features numeric
GDS pipelines accept only numeric node properties (numbers or numeric arrays). Convert categories into encodings such as one-hot vectors, or use graph-derived numerical features.
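One-hot encoding turns a categorical value into a numeric vector that can be stored as a numeric array property. A minimal Python sketch, using a hypothetical `region` property:

```python
def one_hot(values):
    """Encode categorical values as 0/1 vectors over a sorted vocabulary."""
    vocabulary = sorted(set(values))
    index = {value: i for i, value in enumerate(vocabulary)}
    vectors = []
    for value in values:
        vector = [0.0] * len(vocabulary)
        vector[index[value]] = 1.0
        vectors.append(vector)
    return vocabulary, vectors

regions = ["EU", "US", "APAC", "EU"]
vocab, encoded = one_hot(regions)
print(vocab)       # ['APAC', 'EU', 'US']
print(encoded[0])  # [0.0, 1.0, 0.0]
```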
Test multiple candidates
One pipeline can compare several model configurations. Let the validation metric guide your choice.
Retrain regularly
Graph behavior changes over time. Retrain when new nodes, relationships, or labels appear.
When to use GDS vs external ML tools
Use Neo4j GDS when:
- relationships are important
- graph structure improves predictions
- you want fast in-database feature engineering
- you need node classification or link prediction
Use external ML libraries when:
- you need custom neural network architectures
- you need general-purpose regression/classification on non-graph data
- you want to combine graph features with a broader ML stack
A common pattern is to use GDS to compute graph features, then export those features to Python for training in scikit-learn, XGBoost, or TensorFlow.
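When exporting features, node properties are often streamed in long format, one row per node and property; for example, `gds.graph.nodeProperties.stream` in GDS 2.x yields `nodeId`, `nodeProperty`, and `propertyValue` columns. The Python sketch below pivots such rows into a feature matrix for scikit-learn or XGBoost. The rows are made up, and fetching them from Neo4j is left out.

```python
def pivot_rows(rows, feature_names):
    """Turn (nodeId, property, value) rows into an ordered feature matrix."""
    table = {}
    for row in rows:
        table.setdefault(row["nodeId"], {})[row["nodeProperty"]] = row["propertyValue"]
    node_ids = sorted(table)
    # Missing properties default to 0.0 here; choose a policy that fits your data
    matrix = [[table[n].get(name, 0.0) for name in feature_names] for n in node_ids]
    return node_ids, matrix

rows = [
    {"nodeId": 0, "nodeProperty": "pageRank", "propertyValue": 0.8},
    {"nodeId": 0, "nodeProperty": "degree", "propertyValue": 3.0},
    {"nodeId": 1, "nodeProperty": "pageRank", "propertyValue": 0.2},
    {"nodeId": 1, "nodeProperty": "degree", "propertyValue": 1.0},
]
ids, X = pivot_rows(rows, ["degree", "pageRank"])
print(ids, X)  # [0, 1] [[3.0, 0.8], [1.0, 0.2]]
```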
A simple rule of thumb
If your question is “what does this entity look like in the network?” or “which connection should happen next?”, GDS is a great place to train the model.
If your question is “how do I train any arbitrary ML model from scratch?”, GDS is better as a graph feature and graph ML layer than as a full replacement for traditional ML platforms.
Summary
To train ML models using Neo4j GDS, you typically:
- project your graph into GDS
- define the target: a node property for classification or a relationship for link prediction
- generate graph-aware features
- train a node classification or link prediction pipeline
- evaluate the model with the right metric
- write predictions back to Neo4j
That workflow gives you models that understand not just attributes, but also the structure of the graph itself.