Moment Neural Network

Connecting Deep Learning and Spiking Neural Networks Through Second-Order Statistics

Most neural networks treat neural activity as a deterministic value: one layer emits an activation vector, and the next layer transforms it. This abstraction is useful, but it leaves out a central feature of biological neural systems. Real neurons do not produce fixed values. They spike irregularly, fluctuate across trials, and often vary together through weak but meaningful correlations.

The moment-neural-network project addresses that gap. It implements Moment Neural Networks, or MNNs, a class of PyTorch-style models that propagate not only mean neural activity, but also second-order statistics such as variance and covariance. In practical terms, an MNN learns both “what is the average firing-rate-like signal?” and “how do neurons fluctuate together?”

The project is grounded in stochastic neural computing and spiking neural network research. The associated paper, “Learning and inference with correlated neural variability,” frames correlated neural variability not merely as noise to suppress, but as a computational resource for coding, learning, and inference. This repository turns that idea into a trainable software framework.

The Problem

Conventional artificial neural networks usually track only first-order quantities: activations, logits, or firing-rate-like values. That is often enough for classification and regression, but it is not enough to describe the statistics of a spiking neural population.

In real spiking systems, several properties and capabilities matter:

  • trial-to-trial variability in spike counts;
  • Fano factors and other measures of firing variability;
  • weak pairwise correlations between neurons;
  • the effect of stochastic input and internal process noise on decision making;
  • the ability to recover an actual spiking neural network from a trained model.

MNNs represent each layer state as a pair:

(u, cov)

Here, u is the mean activity and cov is the covariance matrix. Linear layers, batch normalization, nonlinear activation, and loss functions are all defined around this pair. The model can therefore optimize not only the mean output, but also uncertainty and correlation structure.
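As a rough illustration of what such a pair looks like in memory, the sketch below builds a batched (u, cov) state with plain PyTorch. The batched shapes, [batch, n] for the mean and [batch, n, n] for the covariance, follow the Poisson encoding helper shown later in this document. The sizes and the way the covariance is generated are assumptions made only for the example.

```python
import torch

torch.manual_seed(0)
batch, n = 4, 3  # illustrative sizes, not taken from the repository

# Mean activity: one vector of length n per sample.
u = torch.rand(batch, n)

# Covariance: one n x n matrix per sample. Building it as A A^T + eps * I
# makes it positive definite by construction.
a = torch.randn(batch, n, n)
cov = a @ a.transpose(-1, -2) + 1e-3 * torch.eye(n)
cov = 0.5 * (cov + cov.transpose(-1, -2))  # remove any round-off asymmetry

# The diagonal holds per-neuron variances; off-diagonals hold covariances.
variance = torch.diagonal(cov, dim1=-2, dim2=-1)  # shape [batch, n]

is_symmetric = torch.equal(cov, cov.transpose(-1, -2))
min_eigenvalue = torch.linalg.eigvalsh(cov).min()

print(u.shape, cov.shape, variance.shape)
print(is_symmetric, min_eigenvalue.item() > 0)
```

A valid layer state keeps cov symmetric with non-negative eigenvalues. Checking both is a cheap sanity test when writing custom layers.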

This puts MNNs in a useful middle ground. They are not ordinary deterministic ANNs, and they are not directly trained SNNs using surrogate gradients. They are differentiable statistical surrogate models for the firing statistics of SNNs. Once trained, their parameters can be mapped back to a corresponding spiking model.

From Activation to Moment Activation

A standard activation function maps a number to another number. ReLU, for example, clips negative values to zero.

An MNN activation is different. It maps input-current statistics to output-spike statistics. Instead of asking what a neuron’s deterministic activation should be, it asks:

If a leaky integrate-and-fire neuron receives random input current with a given mean, variance, and correlation structure, what are the resulting spike-train moments?

The implementation is mainly in:

  • mnn/mnn_core/mnn_utils.py
  • mnn/mnn_core/mnn_pytorch.py
  • mnn/mnn_core/nn/activation.py

The main entry point is OriginMnnActivation. It receives (u, cov), converts covariance into standard deviations and correlations, applies a custom autograd moment activation, and reconstructs the output covariance.

This is the central technical move of the project. It makes stochastic spiking-neuron dynamics usable inside a differentiable deep learning training pipeline.
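The bookkeeping around the activation (covariance to standard deviations and correlations, then back to covariance) can be sketched in a few lines. This is a minimal sketch under assumptions: cov_to_std_corr and std_corr_to_cov are hypothetical helper names, not functions from the repository, and the moment activation itself is left out because it depends on the LIF moment mapping implemented in the project.

```python
import torch

def cov_to_std_corr(cov, eps=1e-12):
    """Split batched covariance matrices into std vectors and correlation matrices.

    Hypothetical helper for illustration; cov has shape [batch, n, n].
    """
    var = torch.diagonal(cov, dim1=-2, dim2=-1)
    std = torch.sqrt(var.clamp_min(eps))  # guard against zero variance
    corr = cov / (std.unsqueeze(-1) * std.unsqueeze(-2))
    return std, corr

def std_corr_to_cov(std, corr):
    """Rebuild covariance from stds and correlations: cov_ij = corr_ij * s_i * s_j."""
    return corr * std.unsqueeze(-1) * std.unsqueeze(-2)

# One sample, two neurons: variances 4 and 1, covariance 1.2.
cov = torch.tensor([[[4.0, 1.2],
                     [1.2, 1.0]]])

std, corr = cov_to_std_corr(cov)
cov_back = std_corr_to_cov(std, corr)

print(std)       # standard deviations per neuron
print(corr)      # unit diagonal, correlation coefficients off the diagonal
print(cov_back)  # matches the input covariance
```

In an actual moment activation, the output standard deviations and correlations would differ from the input ones. Only the final reconstruction step has the form shown in std_corr_to_cov.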

How Layers Propagate Covariance

The main feedforward MNN layer is LinearDuo, implemented in mnn/mnn_core/nn/linear.py. It transforms the mean and covariance together:

u   -> W u + b
cov -> W cov W^T

This is the standard rule for a random vector passing through an affine map. The mean transforms linearly, while the covariance is conjugated by W, so it depends quadratically on the weights. The bias b shifts the mean and leaves the covariance unchanged.
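The rule can be checked numerically. The sketch below implements it as a free function (propagate_linear is a hypothetical name, not the repository's LinearDuo) and compares the result with a Monte Carlo estimate obtained by sampling Gaussian inputs and pushing each sample through the same map. The sizes and sample count are arbitrary.

```python
import torch

torch.manual_seed(0)

def propagate_linear(u, cov, weight, bias=None):
    """Push (mean, covariance) through y = W x + b.

    u: [batch, n_in], cov: [batch, n_in, n_in], weight: [n_out, n_in].
    """
    u_out = u @ weight.T
    if bias is not None:
        u_out = u_out + bias
    cov_out = weight @ cov @ weight.T  # matmul broadcasts over the batch dimension
    return u_out, cov_out

n_in, n_out = 3, 2
weight = torch.randn(n_out, n_in)
bias = torch.randn(n_out)

u = torch.randn(1, n_in)
a = torch.randn(n_in, n_in)
base = a @ a.T + 0.1 * torch.eye(n_in)
cov = (0.5 * (base + base.T)).unsqueeze(0)  # symmetric positive definite

u_out, cov_out = propagate_linear(u, cov, weight, bias)

# Monte Carlo reference: sample inputs, apply the same affine map to each sample.
dist = torch.distributions.MultivariateNormal(u[0], covariance_matrix=cov[0])
x = dist.sample((200_000,))
y = x @ weight.T + bias
mc_mean = y.mean(dim=0)
mc_cov = torch.cov(y.T)  # torch.cov expects variables in rows

print((mc_mean - u_out[0]).abs().max())
print((mc_cov - cov_out[0]).abs().max())
```

Both printed differences should be small compared with the output variances, and they shrink as the number of samples grows.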

A standard MNN MLP block is built from:

LinearDuo
CustomBatchNorm1D
OriginMnnActivation

The model wrappers live in mnn/models/mlp.py. The most important class is MnnMlp, which builds a feedforward MNN from configuration arguments.

Training Still Looks Like PyTorch

One practical strength of the repository is that the outer workflow remains close to normal PyTorch training, even though the internal representation carries second-order statistics.

There are two main ways to use the project.

The first is the wrapper workflow. You describe the model, optimizer, loss, dataset, transforms, and training options in a YAML file, then call the general training pipeline. The MNIST example is in:

examples/mnist/

with the config file:

examples/mnist/mnist_config.yaml

The training pipeline is implemented under:

mnn/utils/training_tools/

The second path is the vanilla PyTorch workflow. You import MNN layers directly, define your own model, and write your own training loop. This is better for research modifications: custom losses, altered activations, new model structures, or integration into an existing experiment.

The notebook docs/tutorials/tutorial_training_mnn_vanilla.ipynb demonstrates this direct style.
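To give a feel for the direct style without depending on the repository's classes, here is a self-contained toy training loop. ToyLinearDuo is a hypothetical stand-in written for this example, not the project's LinearDuo, and the loss (cross-entropy on the mean output plus a small penalty on output variance) is an illustrative choice rather than one of the project's loss functions. The activation is omitted to keep the sketch short.

```python
import torch
from torch import nn

class ToyLinearDuo(nn.Module):
    """Hypothetical stand-in that propagates (u, cov) through one linear map."""

    def __init__(self, n_in, n_out):
        super().__init__()
        self.linear = nn.Linear(n_in, n_out)

    def forward(self, u, cov):
        w = self.linear.weight
        return self.linear(u), w @ cov @ w.T

torch.manual_seed(0)
layer = ToyLinearDuo(8, 3)
optimizer = torch.optim.Adam(layer.parameters(), lr=1e-2)

# Fixed toy batch: Poisson-like inputs, so variance equals the mean.
u = torch.rand(16, 8)
cov = torch.diag_embed(u)
target = torch.randint(0, 3, (16,))

losses = []
for step in range(50):
    optimizer.zero_grad()
    u_out, cov_out = layer(u, cov)
    out_var = torch.diagonal(cov_out, dim1=-2, dim2=-1)
    # Gradients flow through both the mean path and the covariance path.
    loss = nn.functional.cross_entropy(u_out, target) + 1e-2 * out_var.mean()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(losses[0], losses[-1])
```

The loop is ordinary PyTorch: the only difference from a deterministic model is that the forward pass returns a pair and the loss may read both members of it.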

Recovering an SNN From an MNN

One of the main design goals is that a trained MNN can recover the corresponding spiking neural network without an additional fine-tuning stage.

The intended workflow is:

  1. Train an MNN with PyTorch.
  2. Save the model parameters and config.
  3. Build the corresponding SNN from the same config.
  4. Load the trained MNN parameters into the SNN structure.
  5. Simulate the SNN with Poisson or Gaussian inputs.
  6. Record spike trains, spike counts, mean firing rates, and covariance.
  7. Compare the SNN statistics with the MNN prediction.
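Steps 6 and 7 come down to estimating moments from recorded spike counts. The sketch below shows that computation with NumPy on synthetic data: the counts are drawn from independent Poisson distributions as a stand-in for real SNN output, and the array layout [num_trials, num_neurons], the counting window, and the rates are all assumptions for the example. How the result should be normalized before comparing it with an MNN prediction depends on the convention used by the model, which is not shown here.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for simulated SNN output: spike counts per trial and neuron.
num_trials, num_neurons = 5000, 3
window = 1.0                              # counting window, arbitrary time unit
true_rates = np.array([5.0, 10.0, 20.0])  # illustrative rates
counts = rng.poisson(true_rates * window, size=(num_trials, num_neurons))

# First-order statistic: mean firing rate per neuron.
mean_count = counts.mean(axis=0)
mean_rate = mean_count / window

# Second-order statistics: spike-count covariance across trials,
# the Fano factor (variance / mean), and pairwise correlation coefficients.
count_cov = np.cov(counts, rowvar=False)
fano = np.diag(count_cov) / mean_count
count_std = np.sqrt(np.diag(count_cov))
count_corr = count_cov / np.outer(count_std, count_std)

print(mean_rate)
print(fano)
print(count_corr)
```

For this synthetic Poisson data the Fano factors sit near one and the off-diagonal correlations near zero. Real SNN output will generally deviate from both, which is exactly what the comparison with the MNN is meant to test.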

The relevant code is under:

mnn/snn/
mnn/snn/base/

mnn/snn/base/neurons.py implements LIF neurons. mnn/snn/base/currents.py implements Poisson and Gaussian current generators. mnn/snn/functional.py provides validation utilities such as MnnSnnValidate.

This step matters because the MNN is not only a mathematical abstraction over moments. It is designed to be converted back into an executable spiking simulation.
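To make "executable spiking simulation" concrete for a single neuron, the following sketch simulates a leaky integrate-and-fire neuron driven by Gaussian white-noise current and estimates its spike-count statistics over repeated trials. It does not use the repository's neuron or current classes. The model form, dV = (-leak * V + mu) dt + sigma dW with threshold, reset, and refractory period, and every parameter value are assumptions chosen for illustration, and simulate_lif is a hypothetical name.

```python
import numpy as np

def simulate_lif(mu, sigma, t_total=1000.0, dt=0.1, num_trials=200,
                 leak=0.05, v_th=20.0, v_reset=0.0, t_ref=5.0, seed=0):
    """Euler-Maruyama simulation of a noisy LIF neuron (times in ms).

    Returns the spike count of each independent trial.
    """
    rng = np.random.default_rng(seed)
    num_steps = int(round(t_total / dt))
    v = np.full(num_trials, v_reset)
    refractory_left = np.zeros(num_trials)
    counts = np.zeros(num_trials, dtype=np.int64)

    for _ in range(num_steps):
        active = refractory_left <= 0.0
        noise = rng.standard_normal(num_trials) * sigma * np.sqrt(dt)
        dv = (-leak * v + mu) * dt + noise
        # Integrate only trials that are not refractory; count down the others.
        v = np.where(active, v + dv, v)
        refractory_left = np.where(active, refractory_left, refractory_left - dt)
        # Threshold crossing: record a spike, reset, start the refractory period.
        spiked = active & (v >= v_th)
        counts += spiked
        v = np.where(spiked, v_reset, v)
        refractory_left = np.where(spiked, t_ref, refractory_left)
    return counts

counts = simulate_lif(mu=1.5, sigma=1.5)
mean_count = counts.mean()
fano = counts.var(ddof=1) / mean_count

print("mean spike count:", mean_count)
print("Fano factor:", fano)
```

The loop over time steps is the reason such simulations are slower than a single MNN forward pass: the cost grows with the simulated duration divided by the step size, multiplied by the number of trials.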

Quick Start

Install the project through pip:

pip install moment-neural-network

Alternatively, install the project from the repository root:

python -m pip install -e .

Create output and data directories:

mkdir -p checkpoint data

Run the MNIST example:

python examples/mnist/mnist.py --config=examples/mnist/mnist_config.yaml

One important detail: the current examples/mnist/mnist.py has SNN simulation enabled in main(), while train_mnist(config) is commented out. For a first run, change main() to train:

def main():
    config = utils.training_tools.deploy_config()
    train_mnist(config)
    # mnn2snn_simulation(config)

After training, switch back to SNN simulation:

def main():
    config = utils.training_tools.deploy_config()
    # train_mnist(config)
    mnn2snn_simulation(config)

By default, training outputs are written to:

checkpoint/mnist/

Typical outputs include:

  • mnn_net_config.yaml: the saved experiment config;
  • mnn_net_log.txt: training and validation logs;
  • mnn_net.pth: the latest checkpoint;
  • mnn_net_best_model.pth: the best validation checkpoint;
  • mnn_net_snn_validate_result/: SNN simulation results.

Minimal Direct Layer Example

You can also use the MNN layers directly:

import torch
from mnn.mnn_core.nn.linear import LinearDuo
from mnn.mnn_core.nn.custom_batch_norm import CustomBatchNorm1D
from mnn.mnn_core.nn.activation import OriginMnnActivation
from mnn.mnn_core.nn.criterion import CrossEntropyOnMean

class SimpleMNN(torch.nn.Module):
    def __init__(self, input_size=784, hidden_size=100, output_size=10):
        super().__init__()
        self.linear = LinearDuo(input_size, hidden_size)
        self.bn = CustomBatchNorm1D(hidden_size)
        self.act = OriginMnnActivation()
        self.readout = LinearDuo(hidden_size, output_size, bias=True)

    def forward(self, inputs):
        u, cov = inputs
        u, cov = self.linear(u, cov)
        u, cov = self.bn(u, cov)
        u, cov = self.act(u, cov)
        return self.readout(u, cov)

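def encode_poisson_images(images, scale=1.0):
    # Flatten each image into a vector of firing rates: [batch, pixels].
    mean = torch.flatten(images, start_dim=1) * scale
    # Independent Poisson inputs: variance equals the rate, with no cross terms.
    # abs() keeps the diagonal non-negative if pixel values were normalized.
    cov = torch.diag_embed(torch.abs(mean))
    return mean, cov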

model = SimpleMNN()
criterion = CrossEntropyOnMean()

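The encode_poisson_images helper treats image pixels as independent Poisson firing rates. A Poisson variable has variance equal to its mean, and independence makes every off-diagonal term zero, so the input covariance is a diagonal matrix whose entries are the pixel intensities. This is the same basic idea used by input_prepare: flatten_poisson in the MNIST config.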
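The diagonal-covariance assumption is easy to verify empirically. The sketch below draws independent Poisson samples for three made-up rates and compares the sample covariance with diag(rates). The rates and sample count are arbitrary.

```python
import torch

torch.manual_seed(0)

rates = torch.tensor([0.5, 2.0, 8.0])  # illustrative "pixel" firing rates
num_samples = 100_000

# Each row is one independent draw of all three Poisson inputs.
samples = torch.poisson(rates.repeat(num_samples, 1))

sample_mean = samples.mean(dim=0)
sample_cov = torch.cov(samples.T)  # torch.cov expects variables in rows
expected_cov = torch.diag(rates)   # variance = mean, zero off-diagonal

print(sample_mean)
print(sample_cov)
print((sample_cov - expected_cov).abs().max())
```

The sample covariance approaches the diagonal matrix of rates as the number of samples increases, which is the structure that torch.diag_embed produces directly from the mean.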

How to Read the Docs

A practical reading order is:

  1. docs/tutorials/tutorial_moment_activation.ipynb

    Start here to understand how moment activation maps LIF input statistics to output statistics.

  2. docs/tutorials/tutorial_training_mnn.ipynb

    Use this to understand the YAML-based training workflow and the SNN reconstruction pipeline.

  3. docs/tutorials/tutorial_training_mnn_vanilla.ipynb

    Read this if you want to use MNN layers directly in ordinary PyTorch code.

  4. docs/tutorials/tutorial_EI_network.ipynb

    This shows how MNNs can model a classic excitatory/inhibitory recurrent circuit and capture both population-level and neuron-level firing statistics.

  5. publications/

    Use these folders for paper-specific experiments and reproduction work.

Repository Map

mnn/
  mnn_core/          # moment activation, layers, losses, functional utilities
  models/            # MNN/ANN/SNN-style model wrappers
  snn/               # MNN-to-SNN conversion and simulation
  utils/             # training pipeline, dataloaders, preprocessing
  analysis/          # analysis and visualization helpers

docs/                # Sphinx docs and tutorial notebooks
examples/mnist/      # minimal MNIST example
publications/        # paper-specific experiment code

Who This Project Is For

This repository is most useful if you want to:

  • study the computational role of correlated neural variability;
  • train models that can be recovered as SNNs;
  • bring firing-rate covariance into a PyTorch workflow;
  • reproduce or extend MNN-related papers;
  • compare ANN, MNN, and SNN behavior under uncertainty and stochastic dynamics.

If the goal is only to train a high-accuracy image classifier, a conventional CNN or Transformer is more direct. The value of MNNs is different: they are useful when variability, correlation, spiking dynamics, or neuromorphic implementation are part of the research question.

Current Limitations

Full covariance scales quadratically with layer width: a layer of n neurons carries an n × n covariance matrix per sample, so large models can become memory intensive. The MLP path is the most mature part of the repository. Convolution support exists, but it is less developed than the MLP stack.
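A back-of-the-envelope estimate shows how quickly this grows. The helper below (covariance_memory_mib is a hypothetical name) counts only the storage for one layer's dense covariance tensor in a single forward pass, assuming 32-bit floats. Autograd buffers and intermediate tensors would add to this.

```python
def covariance_memory_mib(batch_size, width, bytes_per_element=4):
    """Memory of a dense [batch, width, width] covariance tensor, in MiB."""
    return batch_size * width * width * bytes_per_element / 2**20

batch_size = 64  # illustrative batch size
for width in (100, 1000, 4000):
    mib = covariance_memory_mib(batch_size, width)
    print(f"width={width:5d}  covariance per layer: {mib:10.2f} MiB")
```

Multiplying the width by ten multiplies the covariance storage by one hundred, while the mean vector grows only tenfold.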

Some publication scripts are tied to specific datasets, paths, or experiment environments, so they may need configuration edits before reproducing results. SNN simulation is also slower than an MNN forward pass because it advances a stochastic process over many time steps. In practice, running_time, dt, num_trials, and hardware resources need to be chosen carefully.

Summary

moment-neural-network makes the stochastic statistics of spiking neural systems trainable in a deep learning framework. Instead of learning only the mean output of a network, it also learns uncertainty and correlation structure.

From an engineering perspective, the repository provides a complete path: MNN layers, losses, training wrappers, tutorials, an MNIST example, MNN-to-SNN conversion, SNN simulation, and publication code. From a research perspective, it gives a concrete tool for studying how correlated neural variability can shape learning, inference, and neural coding.

The central idea is simple but consequential: neural variability is not just a nuisance to remove. In this project, it becomes part of the model.