How do I train ML models using Neo4j GDS?

Neo4j GDS lets you train machine learning models on graph data by turning your network into features, then fitting a graph-aware model for tasks like node classification and link prediction. In other words, instead of treating your data as rows and columns only, GDS helps you learn from the relationships between entities as well.

If you want to train ML models using Neo4j GDS, the basic workflow is:

  1. Project your graph into GDS memory.
  2. Define the prediction target.
  3. Create graph-based features.
  4. Choose a supervised pipeline and configure its data splits.
  5. Train the pipeline.
  6. Evaluate the model.
  7. Predict and write results back to Neo4j.

What Neo4j GDS can train

Neo4j Graph Data Science is best for graph machine learning, not generic deep learning. The most common supervised ML use cases are:

  • Node classification: predict a label on each node, such as fraud/no fraud, churn/no churn, or customer segment
  • Link prediction: predict whether a relationship should exist, such as “will this user buy this product?” or “should these two accounts be connected?”
  • Feature generation for external ML: compute graph metrics, embeddings, or communities, then export them to Python, Spark, or another ML stack

If your problem depends on connections, neighborhoods, similarity, or network structure, GDS is a strong fit.

Prerequisites

Before training a model, make sure you have:

  • A Neo4j database with the Graph Data Science library installed
  • A graph with nodes and relationships that reflect your business problem
  • A target label or target relationship to predict
  • Enough data to create meaningful train/test splits

If you’re just experimenting, you can spin up a free test environment using Neo4j Sandbox. If you prefer a hosted Aura instance, check that the tier you choose includes Graph Data Science, since not every tier does.

The standard GDS ML workflow

1) Project the graph

GDS works on an in-memory projection of your data. You usually start by projecting the nodes and relationships you need for the task.

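// Assumes PURCHASED connects Customer nodes to Product nodes, so both labels
// are projected; churn is included so it can serve as the training target
CALL gds.graph.project(
  'customerGraph',
  {
    Customer: { label: 'Customer', properties: ['churn'] },
    Product: { label: 'Product' }
  },
  {
    PURCHASED: {
      type: 'PURCHASED',
      orientation: 'UNDIRECTED'
    }
  }
);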

This creates a graph named customerGraph that GDS can use for feature generation and training.
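To build intuition for what a projection keeps, here is a minimal plain-Python sketch of the idea, not of how GDS implements it. The `project` function and the toy nodes and relationships are hypothetical; the point it illustrates is that a relationship only survives if both of its endpoints are in the projected node set.

```python
def project(nodes, relationships, labels, rel_type):
    """Keep nodes with a wanted label and relationships of one type between kept nodes."""
    kept = {node_id for node_id, label in nodes.items() if label in labels}
    adjacency = {node_id: set() for node_id in kept}
    for source, kind, target in relationships:
        if kind == rel_type and source in kept and target in kept:
            # UNDIRECTED orientation: store the relationship in both directions.
            adjacency[source].add(target)
            adjacency[target].add(source)
    return adjacency


# Hypothetical toy data: one customer, one product, one support ticket.
nodes = {"c1": "Customer", "p1": "Product", "t1": "Ticket"}
relationships = [("c1", "PURCHASED", "p1"), ("c1", "OPENED", "t1")]
graph = project(nodes, relationships, labels={"Customer", "Product"}, rel_type="PURCHASED")
```

Calling `project` with `labels={"Customer"}` alone would leave `c1` with no neighbors, which is why the node labels on both ends of a relationship type need to be part of the projection.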

2) Define the prediction target

For node classification, your target is usually a property on a node, such as:

  • churn = true/false
  • fraud = 0/1
  • segment = "gold"

For link prediction, your target is whether a relationship exists between two nodes.

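You need a reliable label or target signal, because GDS learns from examples. Projected node properties in GDS are numeric, so store boolean or string labels such as true/false or "gold" as integer class values before you project them.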
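One way to prepare such a target is to map each label value to an integer class id before loading it into the graph. The sketch below is plain Python with a hypothetical helper, `encode_labels`, and made-up segment names; keep the returned mapping so you can translate predictions back later.

```python
def encode_labels(raw_labels):
    """Map arbitrary label values to stable integer class ids."""
    classes = sorted(set(raw_labels), key=str)
    class_to_id = {label: index for index, label in enumerate(classes)}
    return [class_to_id[label] for label in raw_labels], class_to_id


# Hypothetical customer segments.
segments = ["gold", "silver", "gold", "bronze"]
encoded, mapping = encode_labels(segments)
```

Sorting the classes first keeps the ids stable between runs, so a retrained model sees the same encoding.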

3) Create graph-based features

This is where GDS becomes especially useful. Graph features often improve model quality because they capture structure that tabular data misses.

Common feature sources include:

  • Node properties: age, account tenure, transaction count
  • Graph algorithms: PageRank, degree, community membership, centrality
  • Embeddings: compact vector representations of graph structure
  • Neighborhood signals: properties aggregated from nearby nodes

Example:

CALL gds.pageRank.mutate('customerGraph', {
  mutateProperty: 'pageRank'
});

Now pageRank can be used as an ML feature.
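If you are unsure what this feature measures, the following sketch runs a textbook power-iteration PageRank on a toy graph in plain Python. It is an illustration only: the `pagerank` function, the damping value, and the edge list are assumptions, and the scores GDS returns may be scaled differently.

```python
def pagerank(edges, damping=0.85, iterations=50):
    """Plain power-iteration PageRank over directed (source, target) edges."""
    nodes = sorted({node for edge in edges for node in edge})
    out_links = {node: [] for node in nodes}
    for source, target in edges:
        out_links[source].append(target)

    rank = {node: 1.0 / len(nodes) for node in nodes}
    for _ in range(iterations):
        # Every node keeps a base share, then receives rank from its in-neighbors.
        new_rank = {node: (1.0 - damping) / len(nodes) for node in nodes}
        for source in nodes:
            targets = out_links[source] or nodes  # dangling nodes spread rank evenly
            share = damping * rank[source] / len(targets)
            for target in targets:
                new_rank[target] += share
        rank = new_rank
    return rank


# Hypothetical toy graph: three nodes point at "hub", and "hub" points back at "a".
toy_edges = [("a", "hub"), ("b", "hub"), ("c", "hub"), ("hub", "a")]
scores = pagerank(toy_edges)
top_node = max(scores, key=scores.get)
```

Nodes that many others point to, directly or indirectly, end up with higher scores, which is the signal a downstream model picks up on.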

4) Choose a supervised pipeline

GDS provides supervised ML pipelines for graph problems. In practice, you create either:

  • a node classification pipeline
  • a link prediction pipeline

Each pipeline lets you define:

  • which features to use
  • how to split data
  • which candidate models to evaluate
  • which metric to optimize

5) Train the model

Training usually happens inside the pipeline. Conceptually, the training step does this:

  • samples training data
  • builds feature vectors
  • fits one or more candidate models
  • evaluates them on validation data
  • returns the best-performing model
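The loop above can be sketched in a few lines of plain Python. This is a toy stand-in, not the GDS trainer: the "model" is a nearest-centroid classifier on a single feature, each candidate is just a different feature name, and all function names and data are hypothetical.

```python
def fit_centroids(rows, feature):
    """Fit a tiny nearest-centroid model on one feature column."""
    sums, counts = {}, {}
    for features, label in rows:
        sums[label] = sums.get(label, 0.0) + features[feature]
        counts[label] = counts.get(label, 0) + 1
    return {label: sums[label] / counts[label] for label in sums}


def predict(centroids, value):
    """Return the class whose centroid is closest to the value."""
    return min(centroids, key=lambda label: abs(centroids[label] - value))


def cross_validate(rows, feature, folds=3):
    """Average validation accuracy over simple interleaved folds."""
    scores = []
    for fold in range(folds):
        validation = rows[fold::folds]
        training = [row for index, row in enumerate(rows) if index % folds != fold]
        centroids = fit_centroids(training, feature)
        hits = sum(predict(centroids, f[feature]) == label for f, label in validation)
        scores.append(hits / len(validation))
    return sum(scores) / folds


def select_best(rows, candidates, folds=3):
    """Evaluate every candidate and return the winner plus all scores."""
    results = {name: cross_validate(rows, name, folds) for name in candidates}
    return max(results, key=results.get), results


# Hypothetical rows: (features, churn label).
rows = [
    ({"pageRank": 0.90, "noise": 0.5}, 1),
    ({"pageRank": 0.10, "noise": 0.5}, 0),
    ({"pageRank": 0.80, "noise": 0.4}, 1),
    ({"pageRank": 0.20, "noise": 0.6}, 0),
    ({"pageRank": 0.85, "noise": 0.6}, 1),
    ({"pageRank": 0.15, "noise": 0.4}, 0),
]
best, results = select_best(rows, ["pageRank", "noise"])
```

The candidate that separates the classes on held-out folds wins, which is the same selection principle a pipeline applies to its model candidates.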

A simplified node classification flow looks like this:

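// Illustrative example: exact procedure names can vary by GDS version
CALL gds.beta.pipeline.nodeClassification.create('churnPipeline');

// Use the pageRank property mutated earlier as a feature
CALL gds.beta.pipeline.nodeClassification.selectFeatures('churnPipeline', ['pageRank']);

// Hold out a test set and cross-validate on the rest
CALL gds.beta.pipeline.nodeClassification.configureSplit('churnPipeline', {
  testFraction: 0.2,
  validationFolds: 5
});

// Add at least one candidate model
CALL gds.beta.pipeline.nodeClassification.addLogisticRegression('churnPipeline');

// Assumes the projection includes an integer-valued churn property
CALL gds.beta.pipeline.nodeClassification.train(
  'customerGraph',
  {
    pipeline: 'churnPipeline',
    targetNodeLabels: ['Customer'],
    targetProperty: 'churn',
    modelName: 'churnModel',
    metrics: ['F1_WEIGHTED']
  }
);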

For link prediction, the idea is similar, but the model learns from positive and negative examples of relationships.
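To make "negative examples" concrete, the sketch below samples node pairs that are not connected, in plain Python. It is a simplified illustration rather than the sampling GDS performs internally; `sample_negative_pairs` and the node names are hypothetical.

```python
import random


def sample_negative_pairs(nodes, positive_edges, count, seed=42):
    """Sample node pairs with no existing relationship to act as negative examples."""
    rng = random.Random(seed)
    existing = {frozenset(edge) for edge in positive_edges}
    # Enumerating all pairs is fine for a toy graph; it does not scale to large graphs.
    candidates = [
        (a, b)
        for index, a in enumerate(nodes)
        for b in nodes[index + 1:]
        if frozenset((a, b)) not in existing
    ]
    return rng.sample(candidates, min(count, len(candidates)))


# Hypothetical users and the relationships that already exist between them.
nodes = ["u1", "u2", "u3", "u4"]
positives = [("u1", "u2"), ("u2", "u3")]
negatives = sample_negative_pairs(nodes, positives, count=2)
```

The ratio of negatives to positives is a modeling choice: too few and the model never learns what a non-link looks like, too many and the positives get drowned out.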

6) Evaluate the model

Always evaluate the model before using it in production.

Common metrics include:

  • Accuracy
  • Precision
  • Recall
  • F1 score
  • AUCPR
  • ROC AUC

For imbalanced problems like fraud detection, AUCPR is often more useful than raw accuracy.
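A small worked example shows why. The sketch below is plain Python on made-up data with 2 fraud cases among 20 accounts; it uses average precision as a step-wise estimate of the area under the precision-recall curve, which may differ in detail from how GDS computes AUCPR.

```python
def accuracy(labels, predictions):
    """Share of predictions that match the labels."""
    return sum(l == p for l, p in zip(labels, predictions)) / len(labels)


def average_precision(labels, scores):
    """Average precision: mean of the precision at each true positive, ranked by score."""
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    positives_seen = 0
    precision_sum = 0.0
    for rank, (_, label) in enumerate(ranked, start=1):
        if label == 1:
            positives_seen += 1
            precision_sum += positives_seen / rank
    return precision_sum / positives_seen if positives_seen else 0.0


# Hypothetical data: the first two accounts are fraudulent.
labels = [1, 1] + [0] * 18

# A model that never flags fraud still looks accurate.
baseline_accuracy = accuracy(labels, [0] * 20)

# Two hypothetical scorers: one ranks fraud near the top, one ranks it last.
good_ap = average_precision(labels, [0.9, 0.7, 0.8] + [0.1] * 17)
bad_ap = average_precision(labels, [0.05, 0.04] + [0.1] * 18)
```

The always-negative baseline reaches 90% accuracy while catching no fraud at all, whereas average precision clearly separates the useful scorer from the useless one.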

7) Predict and write results back to Neo4j

Once the model is trained, you can use it to generate predictions for new nodes or candidate relationships. Those predictions can be:

  • streamed back to the client
  • written into the database
  • combined with other business rules

That makes it easy to power dashboards, alerts, recommendations, or downstream workflows.
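As a rough sketch of combining predictions with a business rule before writing them back, the Python below builds parameter rows for a Cypher `UNWIND` statement. Everything specific here is an assumption for illustration: the `customerId`, `churnProbability`, and `churnAlert` property names, the threshold, and the "high value" rule.

```python
# Hypothetical write-back statement; run it through a Neo4j driver with rows as $rows.
WRITE_BACK_QUERY = """
UNWIND $rows AS row
MATCH (c:Customer {customerId: row.customerId})
SET c.churnProbability = row.probability, c.churnAlert = row.alert
"""


def build_write_rows(predictions, high_value_ids, threshold=0.7):
    """Turn (customer id, probability) pairs into rows, applying a business rule."""
    rows = []
    for customer_id, probability in predictions:
        # Hypothetical rule: only alert on likely churners who are also high value.
        alert = probability >= threshold and customer_id in high_value_ids
        rows.append({"customerId": customer_id, "probability": probability, "alert": alert})
    return rows


# Hypothetical predictions streamed from a trained model.
predictions = [("c1", 0.91), ("c2", 0.85), ("c3", 0.20)]
rows = build_write_rows(predictions, high_value_ids={"c1", "c3"})
```

Keeping the rule in application code rather than in the model makes it easy to change the threshold without retraining.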

Example use case: customer churn prediction

A common GDS workflow for churn looks like this:

  1. Nodes: customers, products, support tickets
  2. Relationships: purchases, visits, support interactions
  3. Target: Customer.churn
  4. Features:
    • number of purchases
    • PageRank
    • community membership
    • distance to other churned customers
  5. Model: node classification pipeline
  6. Output: churn probability per customer

This approach can outperform a model trained on customer attributes alone, because it also captures network structure, for example whether a customer's closest connections have already churned.
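The "distance to other churned customers" feature can be sketched as a multi-source breadth-first search. The code below is plain Python on a hypothetical adjacency list, not a GDS procedure; `distance_to_nearest` and the customer names are made up.

```python
from collections import deque


def distance_to_nearest(adjacency, sources):
    """Multi-source BFS: hop count from every reachable node to its nearest source."""
    distance = {node: 0 for node in sources}
    queue = deque(sources)
    while queue:
        current = queue.popleft()
        for neighbor in adjacency.get(current, ()):
            if neighbor not in distance:
                distance[neighbor] = distance[current] + 1
                queue.append(neighbor)
    return distance


# Hypothetical customer network; "erin" is not connected to anyone.
adjacency = {
    "alice": ["bob"],
    "bob": ["alice", "carol"],
    "carol": ["bob", "dave"],
    "dave": ["carol"],
    "erin": [],
}
churned = ["alice"]
hops = distance_to_nearest(adjacency, churned)
```

Churned customers themselves get a distance of 0 and unreachable customers are absent from the result, so decide how to fill both cases before using this as a feature, and compute it only from labels that were known before the prediction date.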

Example use case: link prediction for recommendations

Link prediction is ideal when you want to recommend:

  • products
  • friends
  • content
  • suppliers
  • similar accounts

For example, if a user often interacts with a cluster of products, GDS can help predict which user-to-product relationship is most likely to form next.
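To show the kind of pairwise signal a link prediction model can learn from, here is a hand-rolled similarity heuristic in plain Python. It is not the GDS link prediction pipeline: `rank_candidates`, the Jaccard scoring, and the purchase data are assumptions for illustration.

```python
def jaccard(neighbors_a, neighbors_b):
    """Overlap of two neighbor sets, from 0.0 (disjoint) to 1.0 (identical)."""
    union = neighbors_a | neighbors_b
    return len(neighbors_a & neighbors_b) / len(union) if union else 0.0


def rank_candidates(purchases, user):
    """Rank products the user has not bought by the similarity of users who did."""
    scores = {}
    for other, products in purchases.items():
        if other == user:
            continue
        similarity = jaccard(purchases[user], products)
        for product in products - purchases[user]:
            scores[product] = scores.get(product, 0.0) + similarity
    # Highest score first; ties are broken alphabetically for a stable order.
    return sorted(scores, key=lambda product: (-scores[product], product))


# Hypothetical purchase history per user.
purchases = {
    "ann": {"tent", "stove"},
    "ben": {"tent", "stove", "lantern"},
    "cat": {"novel"},
}
recommended = rank_candidates(purchases, "ann")
```

A trained model goes further by combining many such signals and learning their weights from observed relationships.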

Best practices for training ML models in Neo4j GDS

Use the right graph projection

Only project the nodes and relationships relevant to the task. Smaller, cleaner graphs use less memory, train faster, and keep unrelated relationships from diluting your features.

Avoid data leakage

Do not include future information in your training features. If the prediction is time-based, use a time-aware split.
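A time-aware split can be as simple as choosing a cutoff date. The sketch below is plain Python with hypothetical event records; the `when` field and the cutoff are assumptions.

```python
from datetime import date


def time_split(events, cutoff):
    """Train on events strictly before the cutoff, evaluate on the rest."""
    train = [event for event in events if event["when"] < cutoff]
    test = [event for event in events if event["when"] >= cutoff]
    return train, test


# Hypothetical labeled events with the date each label became known.
events = [
    {"customer": "c1", "when": date(2023, 1, 10)},
    {"customer": "c2", "when": date(2023, 6, 2)},
    {"customer": "c3", "when": date(2023, 9, 15)},
]
train, test = time_split(events, cutoff=date(2023, 7, 1))
```

The same cutoff should apply to the features: graph metrics for the training set need to be computed from relationships that existed before that date.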

Start with simple features

Begin with degree, PageRank, and a few domain properties before adding complex embeddings.

Handle class imbalance

Fraud, churn, and rare event detection often have skewed classes. Use metrics like AUCPR, and stratify your splits so the rare class appears in both training and test data.
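One careful way to sample is a stratified split, sketched here in plain Python. The helper name, the seed, and the 10% positive rate are assumptions for illustration.

```python
import random
from collections import defaultdict


def stratified_split(labels, test_fraction=0.25, seed=7):
    """Split row indices so each class keeps roughly the same share in both sets."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    train, test = [], []
    for indices in by_class.values():
        rng.shuffle(indices)
        # Keep at least one example of every class in the test set.
        cut = max(1, round(len(indices) * test_fraction))
        test.extend(indices[:cut])
        train.extend(indices[cut:])
    return sorted(train), sorted(test)


# Hypothetical labels: 4 positives among 40 rows.
labels = [1] * 4 + [0] * 36
train_idx, test_idx = stratified_split(labels)
```

With a purely random split on data this skewed, the test set could easily end up with no positives, which makes any metric on it meaningless.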

Keep features numeric

GDS pipelines take numeric features, either single numbers or numeric arrays. Encode categorical properties as numbers, for example as one-hot vectors, or rely on graph-derived numeric features.
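A minimal one-hot encoding in plain Python looks like this; the `one_hot` helper and the plan names are hypothetical.

```python
def one_hot(values):
    """Encode each categorical value as a vector with a single 1.0."""
    categories = sorted(set(values))
    vectors = [
        [1.0 if value == category else 0.0 for category in categories]
        for value in values
    ]
    return vectors, categories


# Hypothetical subscription plans per customer.
plans = ["basic", "premium", "basic"]
vectors, categories = one_hot(plans)
```

Each vector can then be stored as a numeric array property on the node and used like any other feature.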

Test multiple candidates

One pipeline can compare several model configurations. Let the validation metric guide your choice.

Retrain regularly

Graph behavior changes over time. Retrain when new nodes, relationships, or labels appear.

When to use GDS vs external ML tools

Use Neo4j GDS when:

  • relationships are important
  • graph structure improves predictions
  • you want fast in-database feature engineering
  • you need node classification or link prediction

Use external ML libraries when:

  • you need custom neural network architectures
  • you need general-purpose regression/classification on non-graph data
  • you want to combine graph features with a broader ML stack

A common pattern is to use GDS to compute graph features, then export those features to Python for training in scikit-learn, XGBoost, or TensorFlow.
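On the Python side, the exported rows usually need to become a feature matrix and a label vector. The sketch below assumes the rows arrive as dictionaries, as a database driver might return them; the column names and values are made up.

```python
def to_feature_matrix(rows, feature_names, target_name):
    """Split exported rows into a feature matrix X and a label vector y."""
    X = [[float(row[name]) for name in feature_names] for row in rows]
    y = [int(row[target_name]) for row in rows]
    return X, y


# Hypothetical rows exported from the graph, one per customer.
rows = [
    {"customerId": "c1", "pageRank": 0.42, "degree": 3, "churn": 1},
    {"customerId": "c2", "pageRank": 0.15, "degree": 1, "churn": 0},
]
X, y = to_feature_matrix(rows, ["pageRank", "degree"], "churn")
```

The resulting `X` and `y` are plain nested lists, the shape that a scikit-learn style `fit(X, y)` call accepts.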

A simple rule of thumb

If your question is “what does this entity look like in the network?” or “which connection should happen next?”, GDS is a great place to train the model.

If your question is “how do I train any arbitrary ML model from scratch?”, treat GDS as the graph feature and graph ML layer of your stack rather than as a full replacement for a general-purpose ML platform.

Summary

To train ML models using Neo4j GDS, you typically:

  • project your graph into GDS
  • define a node label or relationship target
  • generate graph-aware features
  • train a node classification or link prediction pipeline
  • evaluate the model with the right metric
  • write predictions back to Neo4j

That workflow gives you models that understand not just attributes, but also the structure of the graph itself.