Mamba and PyTorch: Integration and Setup

The Mamba state space model architecture is implemented primarily through a PyTorch-based ecosystem, requiring specific CUDA extensions, dependency chains, and installation configurations to function correctly. This page covers the integration surface between Mamba and PyTorch — including hardware-aware kernel dependencies, environment setup, package variants, and the practical boundaries where different installation paths apply. Practitioners deploying Mamba for sequence modeling, NLP, or genomics workloads will encounter distinct setup requirements that differ meaningfully from standard transformer-based pipelines.

Definition and scope

Mamba's PyTorch integration spans two distinct layers: the Python-level model code and the low-level CUDA kernels that implement selective scan operations. The reference implementation, published by Albert Gu and Tri Dao on GitHub under the state-spaces/mamba repository, depends on torch (PyTorch) along with causal-conv1d and mamba-ssm as the primary packages. The mamba-ssm package version 1.x targets PyTorch 1.13 through 2.x and requires CUDA 11.6 or higher for kernel compilation.

The scope of this integration is narrower than a general deep learning framework integration. Mamba does not use PyTorch's standard nn.MultiheadAttention or transformer encoder APIs. Instead, it relies on custom CUDA extensions — specifically a selective scan kernel and a causal depthwise convolution kernel — that must be compiled or installed from pre-built wheels. This distinguishes Mamba from architectures that run on PyTorch's native operators alone.

The Mamba architecture overview describes the SSM block structure that these kernels serve. For broader context on the state space model family, the State Space Models Explained reference defines the mathematical framework underpinning these operations.

The Python package mamba-ssm is listed on PyPI, maintained under the Apache 2.0 license, and tracked through the official state-spaces/mamba GitHub repository. PyTorch itself is governed by the PyTorch Foundation, an independent entity under the Linux Foundation umbrella since 2022.

How it works

Installing Mamba in a functional PyTorch environment involves a sequenced dependency resolution across four components:

  1. CUDA toolkit installation — CUDA 11.6 minimum; CUDA 12.x is supported for mamba-ssm versions 1.2.0 and later. The CUDA version must match the PyTorch CUDA build variant.
  2. PyTorch installation — Standard pip or conda install targeting the correct CUDA-enabled wheel (e.g., torch==2.1.0+cu118).
  3. causal-conv1d package — A prerequisite for mamba-ssm, implementing the depthwise convolution component. Installed via pip install "causal-conv1d>=1.1.0" (the quotes prevent the shell from treating >= as a redirect).
  4. mamba-ssm package — Installed via pip install mamba-ssm. On machines with compatible CUDA toolchains, this triggers JIT compilation or uses pre-built binary wheels.
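Once these four steps complete, the dependency chain can be sanity-checked from Python. The sketch below uses only the standard library; check_mamba_deps is a hypothetical helper written for this page, not part of mamba-ssm, and the tuple entries are the import names used by the reference implementation:

```python
import importlib.util

def check_mamba_deps(packages=("torch", "causal_conv1d", "mamba_ssm")):
    """Report which packages in the Mamba dependency chain are importable."""
    return {name: importlib.util.find_spec(name) is not None for name in packages}

# Any False entry indicates a missing link in the install sequence above.
status = check_mamba_deps()
```

If torch is present but mamba_ssm is not, the failure usually traces back to a CUDA mismatch during kernel compilation rather than a pip-level problem.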

The selective scan operation — the computational core of Mamba — is exposed through mamba_ssm.ops.selective_scan_interface. This interface wraps the CUDA kernel and falls back to a pure PyTorch reference implementation when CUDA is unavailable, though that fallback is substantially slower (benchmarks in the original Gu & Dao 2023 paper show hardware-aware implementations running at roughly 3–4× the throughput of naive PyTorch equivalents at sequence length 2048).
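The recurrence that the fallback path computes can be sketched in a few lines. The following is an illustrative NumPy reference of the discretized selective scan (hidden state h_t = exp(Δ_t A) h_{t-1} + Δ_t B_t x_t, output y_t = C_t h_t), with simplified shapes and naming assumed for this page; it is not the mamba_ssm implementation:

```python
import numpy as np

def selective_scan_ref(x, delta, A, B, C):
    """Sequential reference scan.
    x, delta: (L, D); A: (D, N); B, C: (L, N)."""
    L, D = x.shape
    N = A.shape[1]
    h = np.zeros((D, N))
    ys = []
    for t in range(L):
        dA = np.exp(delta[t][:, None] * A)       # discretized state matrix, (D, N)
        dBx = (delta[t] * x[t])[:, None] * B[t]  # discretized input term, (D, N)
        h = dA * h + dBx                         # state update
        ys.append(h @ C[t])                      # readout, (D,)
    return np.stack(ys)                          # (L, D)
```

The CUDA kernel computes the same recurrence but fuses the discretization, scan, and readout into a single hardware-aware pass, which is the source of the throughput gap noted above.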

The MambaLMHeadModel and MixerModel classes follow PyTorch's nn.Module convention, making them compatible with standard PyTorch training loops, DataLoader, and optimizer APIs. Gradient checkpointing, mixed-precision training via torch.cuda.amp, and distributed training through torch.distributed all function with Mamba modules using standard PyTorch patterns. The Mamba hardware-aware algorithms page documents the kernel design rationale in detail.
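Because the modules are plain nn.Module subclasses, they drop into an ordinary PyTorch training step. The sketch below uses a trivial stand-in module (TinyMixer, a hypothetical placeholder defined here so the example runs without mamba-ssm installed); a mamba_ssm Mamba block with the same (batch, seqlen, d_model) contract would slot in identically:

```python
import torch
from torch import nn

# Hypothetical stand-in; replace with a real mamba_ssm module when available.
class TinyMixer(nn.Module):
    def __init__(self, d_model: int = 16):
        super().__init__()
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seqlen, d_model)
        return self.proj(x)

model = TinyMixer()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

x = torch.randn(2, 8, 16)        # (batch, seqlen, d_model)
loss = model(x).pow(2).mean()    # placeholder objective
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Nothing Mamba-specific appears in the loop itself; that is the practical meaning of nn.Module compatibility.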

Common scenarios

Scenario 1: Clean GPU environment setup
A researcher with an A100-80GB GPU running CUDA 12.1 installs PyTorch 2.1.0 with CUDA 12.1 support, then installs causal-conv1d and mamba-ssm sequentially. Pre-built wheels are available for this combination, so no local compilation is required. Total installation time is under 5 minutes.

Scenario 2: Pre-trained model loading via Hugging Face
mamba-ssm integrates with the Hugging Face transformers ecosystem through a compatibility layer. The Mamba Hugging Face page covers the AutoModelForCausalLM path for models hosted on the Hugging Face Hub, which requires mamba-ssm to be present as a backend dependency even when loading via transformers.

Scenario 3: CPU-only or non-CUDA environments
On CPU-only machines, mamba-ssm installs but the CUDA kernels are absent. The pure-PyTorch reference path activates automatically. This path is functional for small-scale testing but is not suitable for training or inference at production sequence lengths. The Mamba inference optimization page covers the performance gap in quantitative terms.

Scenario 4: Docker and containerized deployments
NVIDIA's NGC container registry provides base images with CUDA pre-configured. Teams using containerized ML infrastructure typically build from nvcr.io/nvidia/pytorch base images (tagged by CUDA and PyTorch version) and install mamba-ssm at container build time to avoid runtime compilation.

Decision boundaries

The choice of installation approach follows from hardware and environment constraints:

Condition → Recommended path
GPU with CUDA 11.6–12.x, PyTorch pre-installed → pip install causal-conv1d mamba-ssm from PyPI
Conda environment management → Use pip within conda; conda-forge does not maintain mamba-ssm wheels
HPC cluster without root access → Build from source using pip install --user with module-loaded CUDA
Inference-only, no training → Consider the mamba-ssm[causal-conv1d] minimal install or pre-quantized checkpoints
Mamba2 architecture → Requires mamba-ssm>=2.0.0; see Mamba2 improvements

The most common failure mode during setup is a CUDA version mismatch between the installed PyTorch build and the system CUDA toolkit. PyTorch's torch.version.cuda attribute and the system nvcc --version output must align to the same major.minor version for kernel compilation to succeed.
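The major.minor comparison can be scripted. The helper below is a hypothetical illustration written for this page; the version strings are passed in explicitly rather than read from torch.version.cuda or parsed out of nvcc --version:

```python
def cuda_versions_match(torch_cuda: str, nvcc_cuda: str) -> bool:
    """True when two CUDA version strings agree on major.minor.
    e.g. torch.version.cuda gives '11.8', nvcc reports '11.8.89'."""
    def major_minor(version: str) -> tuple:
        return tuple(int(part) for part in version.split(".")[:2])
    return major_minor(torch_cuda) == major_minor(nvcc_cuda)
```

A False result here before running pip install mamba-ssm saves a long and cryptic compilation failure later.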

For practitioners beginning to orient in this ecosystem, the Mamba reference index provides a structured entry point into implementation, benchmarking, and domain application resources. Fine-tuning workflows that build on top of a correctly installed base environment are covered in Mamba fine-tuning.

References