MM-DLS

Multi-task deep learning based on PET/CT images for the diagnosis and prognosis prediction of advanced non-small cell lung cancer

Overview

MM-DLS is a multi-modal, multi-task deep learning framework for the diagnosis, staging, and prognosis prediction of advanced non-small cell lung cancer (NSCLC). It integrates multi-source data including CT images, PET metabolic parameters, and clinical information to provide a unified, non-invasive decision-making tool for personalized treatment planning.

This repository implements the full MM-DLS pipeline, consisting of:

  • Lung-lesion segmentation with cross-attention transformer
  • Multi-modal feature fusion (CT, PET, Clinical)
  • Multi-task learning: Pathological classification, TNM staging, DFS and OS survival prediction
  • Cox proportional hazards survival loss

The framework supports both classification (adenocarcinoma vs squamous cell carcinoma) and survival risk prediction tasks, and has been validated on large-scale multi-center clinical datasets.


Key Features

  • Multi-modal fusion: Combines CT-based imaging features, PET metabolic biomarkers (SUVmax, SUVmean, SUVpeak, TLG, MTV), and structured clinical variables (age, sex, smoking status, smoking duration, smoking cessation history, tumor size).
  • Multi-task learning: Simultaneous optimization for:
    • Histological subtype classification (LUAD vs LUSC)
    • TNM stage classification (I-II, III, IV)
    • Disease-free survival (DFS) prediction
    • Overall survival (OS) prediction
  • Attention-based feature fusion: Transformer cross-attention module to integrate lung-lesion spatial information.
  • Survival modeling: Incorporates Cox Proportional Hazards loss for survival time prediction.
  • Flexible data simulation and loading: Includes utilities for synthetic data generation and multi-slice 2D volume processing.
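As a concrete illustration of the multi-task setup described above, here is a minimal PyTorch sketch of a shared patient embedding feeding four task heads. Names such as `MultiTaskHead` and the layer sizes are illustrative assumptions, not the repository's actual classes:

```python
import torch
import torch.nn as nn

class MultiTaskHead(nn.Module):
    """Toy multi-task head: one shared patient embedding feeds a
    subtype classifier, a TNM-stage classifier, and two survival
    risk scores (DFS, OS)."""

    def __init__(self, in_dim: int = 256):
        super().__init__()
        self.shared = nn.Sequential(nn.Linear(in_dim, 128), nn.ReLU())
        self.subtype = nn.Linear(128, 1)    # LUAD-vs-LUSC logit
        self.stage = nn.Linear(128, 3)      # I-II / III / IV logits
        self.dfs_risk = nn.Linear(128, 1)   # higher score = higher hazard
        self.os_risk = nn.Linear(128, 1)

    def forward(self, x: torch.Tensor):
        h = self.shared(x)
        return self.subtype(h), self.stage(h), self.dfs_risk(h), self.os_risk(h)

# Shape check on a random batch of 4 patient embeddings
heads = MultiTaskHead()
subtype, stage, dfs, os_risk = heads(torch.randn(4, 256))
```

Sharing the trunk while keeping separate heads lets the four losses regularize each other, which is the usual motivation for multi-task learning in this setting.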

Architecture

The overall MM-DLS system consists of:


  1. Segmentation Module (LungLesionSegmentor):

    • Shared ResNet encoder to extract features from CT images.
    • Dual decoders for lung and lesion segmentation.
    • Transformer-based cross-attention module for enhanced spatial feature interaction between lung and lesion regions.
  2. Feature Encoders:

    • LesionEncoder: 2D convolutional encoder for lesion patches.
    • SpaceEncoder: 2D convolutional encoder for lung-space contextual patches.
  3. Attention Fusion Module:

    • LesionAttentionFusion: Multi-head attention to fuse lesion and lung features into compact patient-level representations.
  4. Patient-Level Fusion Model (PatientLevelFusionModel):

    • Fully connected network that combines imaging, PET, and clinical features.
    • Outputs classification logits, DFS and OS risk scores.
  5. Loss Functions:

    • Binary cross-entropy loss for classification.
    • Cox proportional hazards loss (CoxPHLoss) for survival prediction.
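The repository's `CoxphLoss.py` is not reproduced here; as a sketch, the negative Cox partial log-likelihood (Breslow-style handling of ties, assuming higher scores mean higher hazard) can be written in a few lines of PyTorch:

```python
import torch

def coxph_loss(risk: torch.Tensor, time: torch.Tensor,
               event: torch.Tensor) -> torch.Tensor:
    """Negative Cox partial log-likelihood.
    risk:  (N,) predicted log-hazard scores
    time:  (N,) follow-up times
    event: (N,) 1.0 if the event occurred, 0.0 if censored"""
    order = torch.argsort(time, descending=True)   # longest follow-up first
    risk, event = risk[order], event[order]
    # After sorting, the risk set of patient i is exactly patients 0..i,
    # so a cumulative log-sum-exp gives each term's denominator.
    log_cumsum = torch.logcumsumexp(risk, dim=0)
    partial_ll = (risk - log_cumsum) * event       # censored terms drop out
    return -partial_ll.sum() / event.sum().clamp(min=1)

# Toy batch: the highest-risk patient fails first, the censored one last
loss = coxph_loss(torch.tensor([2.0, 1.0, 0.5]),
                  torch.tensor([4.0, 9.0, 12.0]),
                  torch.tensor([1.0, 1.0, 0.0]))
```

The loss is fully differentiable, so the DFS and OS risk heads can be trained end-to-end alongside the classification heads.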

Code Structure

  • ModelLesionEncoder.py: Lesion image encoder extracting discriminative features from multi-slice tumor regions.
  • ModelSpaceEncoder.py: Lung space encoder modeling anatomical and spatial context beyond the lesion.
  • LesionAttentionFusion.py: Attention-based fusion module for adaptive integration of lesion and spatial features.
  • ClinicalFusionModel.py: Patient-level fusion network combining imaging features, radiomics, PET signals, and clinical variables.
  • HierMM_DLS.py: Core hierarchical multi-modal deep learning model supporting multi-task learning: (1) subtype classification; (2) TNM stage prediction; (3) DFS and OS modeling.
  • CoxphLoss.py: Cox proportional hazards loss for survival modeling with censored data.
  • PatientDataset.py: Patient dataset loader supporting imaging, radiomics, PET, clinical variables, survival outcomes, and treatment labels.
  • LungLesionSegmentation.py: Lung-lesion segmentation model.
  • ImageDataLoader.py: Image preprocessing and loading utilities for multi-slice inputs.
  • plot_results.py: Visualization utilities for Kaplan–Meier curves, hazard ratios, and survival analysis results.

Data Format

The input data is organized per patient as follows:

Imaging Data:

  • CT slices (PNG format)
  • Lung masks (binary masks, PNG)
  • Lesion masks (binary masks, PNG)
  • Slices grouped per patient ID

Tabular Data:

  • Radiomics features: 128-dimensional vector (PyRadiomics extracted)
  • PET features: [SUVmax, SUVmean, SUVpeak, TLG, MTV]
  • Clinical features: [Age, Sex, Smoking Status, Smoking Duration, Smoking Cessation, Tumor Diameter]
  • Survival data: DFS time/event, OS time/event
  • Classification label: LUAD (0) or LUSC (1)

Simulated data utilities are provided for experimentation and reproducibility.
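A minimal sketch of what one simulated patient record might look like; the field names, array shapes, and `simulate_patient` helper are illustrative assumptions, not the repository's exact schema:

```python
import numpy as np

rng = np.random.default_rng(0)

def simulate_patient(pid: str) -> dict:
    """Generate one synthetic patient record matching the layout above
    (illustrative field names, not the repository's exact schema)."""
    return {
        "patient_id": pid,
        "ct_slices": rng.random((16, 128, 128), dtype=np.float32),   # 16 CT slices
        "lesion_mask": (rng.random((16, 128, 128)) > 0.95).astype(np.uint8),
        "radiomics": rng.random(128, dtype=np.float32),              # 128-dim vector
        "pet": rng.random(5, dtype=np.float32),   # SUVmax, SUVmean, SUVpeak, TLG, MTV
        "clinical": rng.random(6, dtype=np.float32),  # age, sex, smoking, ...
        "dfs": (float(rng.exponential(24.0)), int(rng.integers(0, 2))),  # (months, event)
        "os": (float(rng.exponential(36.0)), int(rng.integers(0, 2))),
        "label": int(rng.integers(0, 2)),         # 0 = LUAD, 1 = LUSC
    }

sample = simulate_patient("P0001")
```

Records of this shape are enough to exercise the full training loop without access to clinical data.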


Installation

# Clone the repository
git clone https://github.com/your_username/MM-DLS-NSCLC.git
cd MM-DLS-NSCLC

# Create and activate the environment
conda create -n mm_dls python=3.10 -y
conda activate mm_dls

# Install dependencies
pip install -r requirements.txt

Usage

🔽 Download Pretrained Models

Pretrained MM-DLS models are available for direct download:

  • MM-DLS (full multimodal, best checkpoint)
    ⬇️ Download Pretrained Model (size: 1.3 MB)

The MM-DLS model is intentionally lightweight (~1.3 MB): it employs compact CNN encoders and MLP-based multimodal fusion rather than large pretrained backbones, enabling efficient deployment and fast inference.

After downloading, place the model files under the ./MODEL/ directory.

Training:

python train_patient_model.py

Evaluation:

python test.py

Example Forward Pass (open the notebook in Jupyter):

jupyter notebook run_sample.ipynb

Model Performance (from publication)

Histological Subtype Classification:

  • AUC: 0.85–0.92 across cohorts
  • AP: 0.81–0.86

TNM Stage Prediction:

  • AUC: Stage I-II (0.86–0.96), Stage III (0.85–0.95), Stage IV (0.83–0.95)
  • AP and calibration maintained across internal and external sets

DFS & OS Prognosis:

  • C-index: up to 0.75
  • Time-dependent AUC (1/2/3 years): 0.77–0.91
  • Brier score: consistently < 0.2 for DFS and < 0.3 for OS
  • Superior to single-modality models (clinical-only or imaging-only)
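For readers unfamiliar with the C-index reported above: it is the fraction of comparable patient pairs that the model orders correctly (the patient with the earlier observed event should have the higher risk score), with ties counted as half. A plain-Python sketch, not the evaluation code used in the paper:

```python
def concordance_index(risk, time, event):
    """C-index: fraction of comparable pairs ordered correctly.
    risk: predicted risk scores; time: follow-up times;
    event: 1 if the event was observed, 0 if censored."""
    num, den = 0.0, 0
    n = len(risk)
    for i in range(n):
        if not event[i]:
            continue  # only patients with an observed event anchor a pair
        for j in range(n):
            if time[i] < time[j]:  # j was still event-free at time[i]
                den += 1
                if risk[i] > risk[j]:
                    num += 1.0
                elif risk[i] == risk[j]:
                    num += 0.5
    return num / den if den else float("nan")

# Perfectly anti-ordered risks over uncensored times give a C-index of 1.0
ci = concordance_index([3.0, 2.0, 1.0], [1.0, 2.0, 3.0], [1, 1, 1])
```

A C-index of 0.5 corresponds to random ordering, so the reported value of up to 0.75 indicates a clearly informative risk score.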

Reference

Please cite our original publication when using this work:

License

This project is licensed under the MIT License.

⚠️ Notice: The pretrained model is shared solely for research validation purposes and should not be used, distributed, or cited before the associated study is formally published.

Contact

For any questions or collaborations, please contact:

Dr. Fang Dai: [email protected]
