In this article, I'll explain the paper "Zero-Shot ECG Classification with Multimodal Learning and Test-time Clinical Knowledge Enhancement" [1], which was recently accepted at ICML 2024. If you are passionate about AI and healthcare, this article is definitely for you!
If you want to start playing with the code right away after reading the article, you can find the whole project at this GitHub link.
The rest of the article will cover the following points:
- What is ECG Classification?
- The Need for Multimodal Learning in ECG Classification
- MERL: Multimodal ECG Representation Learning
- Cross-Modal Alignment in MERL
- Uni-Modal Alignment in MERL
- Test-time Clinical Knowledge Enhanced Prompt Engineering (CKEPE)
- Results achieved by this framework
What is ECG Classification?
ECG classification is the process of analyzing electrocardiogram (ECG) signals to identify different types of heart conditions. These signals, which represent the electrical activity of the heart over a period of time, are crucial in diagnosing various cardiac disorders. Accurate classification of ECG signals can help in early detection and treatment of heart diseases, potentially saving lives.
The Need for Multimodal Learning in ECG Classification
Multimodal learning is valuable in ECG classification because ECG signals are complex and their associated reports carry rich clinical context. Traditional ECG self-supervised learning (eSSL) methods, which use only the ECG signals, often fail to capture the high-level semantic information that accurate diagnosis depends on. They can also distort semantic information through input-level data augmentation (for example, cropping or masking a segment of the signal can remove the very waveform feature that defines a diagnosis), which leads to suboptimal performance. Combining ECG signals with clinical reports lets the model learn more comprehensive and accurate representations. According to the authors, this approach improves zero-shot classification, making it possible to classify ECGs without annotated training data and thereby addressing the limitations of eSSL methods.
MERL: Multimodal ECG Representation Learning
MERL (Multimodal ECG Representation Learning) is designed to leverage both ECG signals and their associated clinical reports for representation learning. This enables zero-shot ECG classification with text prompts, without the need for annotated training data. During training, MERL uses Cross-Modal Alignment (CMA) to align ECG and text features in a shared latent space, and Uni-Modal Alignment (UMA) to apply contrastive learning to ECG embeddings directly in the latent space; both avoid the semantic distortion common in traditional eSSL methods. At test time, Clinical Knowledge Enhanced Prompt Engineering (CKEPE) dynamically generates descriptive prompts based on external clinical knowledge, which further improves classification accuracy. In the paper's benchmarks, zero-shot MERL outperforms eSSL methods even when those methods are linearly probed with a portion of the labeled training data.
Cross-Modal Alignment in MERL
Cross-Modal Alignment (CMA) in MERL aligns ECG features with the clinical knowledge contained in the paired reports. It uses two encoders, one for ECG signals and one for text reports, followed by nonlinear projectors that map both outputs to embeddings of the same dimensionality in a shared latent space. CMA then applies a contrastive loss that maximizes the similarity between matched ECG-report embeddings and minimizes it for unmatched pairs. As a result, the learned ECG representations are enriched with clinical context, which leads to more accurate and robust ECG classification.
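The CMA loss is a symmetric contrastive (InfoNCE) loss. In a standard CLIP-style formulation (the symbols here are my own and may differ from the paper's notation), take a batch of B ECG-report pairs, let s_ij be the cosine similarity between the i-th ECG embedding and the j-th report embedding, and let τ be a temperature. The ECG-to-report term for sample i is −log( exp(s_ii/τ) / Σ_j exp(s_ij/τ) ), the report-to-ECG term is −log( exp(s_ii/τ) / Σ_j exp(s_ji/τ) ), and the CMA loss sums both terms over all i and divides by 2B.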
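The CMA loss described above can be sketched in pure Python as a CLIP-style symmetric InfoNCE loss. This is a minimal, framework-free illustration, not the authors' code: the embeddings are plain lists of floats standing in for the projected outputs of the ECG and text encoders, and temperature=0.07 is an assumed CLIP-style default rather than a value taken from the paper.

```python
import math

# Minimal sketch of a CLIP-style symmetric InfoNCE loss as a stand-in for MERL's CMA loss.
# Assumption: temperature=0.07 is a common CLIP-style default, not MERL's published value.


def l2_normalize(vec):
    """Scale a vector to unit length so dot products become cosine similarities."""
    norm = math.sqrt(sum(x * x for x in vec)) or 1.0  # guard against all-zero vectors
    return [x / norm for x in vec]


def log_softmax_at(logits, k):
    """Numerically stable log(softmax(logits)[k])."""
    m = max(logits)
    return logits[k] - (m + math.log(sum(math.exp(x - m) for x in logits)))


def cma_loss(ecg_embs, report_embs, temperature=0.07):
    """Row i of ecg_embs is the positive partner of row i of report_embs."""
    ecg = [l2_normalize(v) for v in ecg_embs]
    txt = [l2_normalize(v) for v in report_embs]
    n = len(ecg)
    # sim[i][j]: cosine similarity of ECG i and report j, divided by the temperature
    sim = [[sum(a * b for a, b in zip(ecg[i], txt[j])) / temperature
            for j in range(n)] for i in range(n)]
    total = 0.0
    for i in range(n):
        ecg_to_txt = sim[i]                         # ECG i scored against every report
        txt_to_ecg = [sim[j][i] for j in range(n)]  # report i scored against every ECG
        total -= log_softmax_at(ecg_to_txt, i) + log_softmax_at(txt_to_ecg, i)
    return total / (2 * n)  # average of both directions over the batch


ecg = [[1.0, 0.0], [0.0, 1.0]]
matched = [[0.9, 0.1], [0.1, 0.9]]  # report i describes ECG i
swapped = [[0.1, 0.9], [0.9, 0.1]]  # pairs deliberately mismatched
print(cma_loss(ecg, matched) < cma_loss(ecg, swapped))  # True
```

In a real training loop the same computation would run on batched tensors in a deep learning framework, so that gradients flow back through the projectors into both encoders.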
Uni-Modal Alignment in MERL
Uni-Modal Alignment (UMA) complements CMA by applying contrastive learning within the ECG domain. Instead of input-level data augmentation, UMA augments in the latent space to preserve semantic integrity: two independent dropout operations are applied to each ECG embedding, producing two views that form a positive pair. The UMA loss treats these two views of the same ECG as positives and views of different ECGs in the batch as negatives, which encourages robust representations. By minimizing both the UMA and CMA losses during training, MERL learns ECG features that are both clinically relevant and robust to semantic distortion.
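The UMA loss follows the same contrastive form as the CMA loss, except that both sides come from the ECG branch. Here M1 and M2 denote two independent dropout masks applied to the latent representation of each ECG; the two masked versions of the same ECG form the positive pair.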
In summary, the model learns representative ECG features by jointly minimizing UMA loss and CMA loss, and the overall training loss can be written as:
Total Loss = CMA Loss + UMA Loss
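UMA and the combined objective can be sketched the same way. This is a minimal sketch, not the authors' implementation, and it rests on stated assumptions: dropout is applied directly to the ECG embeddings, UMA reuses the symmetric InfoNCE form, and temperature=0.07 is a CLIP-style default. The default p=0.1 is the dropout ratio the ablation found to work best. The helper names (info_nce, dropout, uma_loss, total_loss) are hypothetical.

```python
import math
import random

# Minimal sketch of MERL-style training losses (illustrative, not the authors' code).
# Assumptions: dropout acts on raw ECG embeddings, UMA reuses the symmetric
# InfoNCE form, and temperature=0.07 is a CLIP-style default.


def l2_normalize(vec):
    norm = math.sqrt(sum(x * x for x in vec)) or 1.0  # guard against all-zero vectors
    return [x / norm for x in vec]


def log_softmax_at(logits, k):
    m = max(logits)
    return logits[k] - (m + math.log(sum(math.exp(x - m) for x in logits)))


def info_nce(a_embs, b_embs, temperature=0.07):
    """Symmetric InfoNCE: row i of a_embs is the positive for row i of b_embs."""
    a = [l2_normalize(v) for v in a_embs]
    b = [l2_normalize(v) for v in b_embs]
    n = len(a)
    sim = [[sum(x * y for x, y in zip(a[i], b[j])) / temperature for j in range(n)]
           for i in range(n)]
    total = 0.0
    for i in range(n):
        total -= log_softmax_at(sim[i], i) + log_softmax_at([sim[j][i] for j in range(n)], i)
    return total / (2 * n)


def dropout(vec, p, rng):
    """Inverted dropout: zero each entry with probability p and rescale the survivors."""
    keep = 1.0 - p
    return [x / keep if rng.random() >= p else 0.0 for x in vec]


def uma_loss(ecg_embs, p=0.1, temperature=0.07, seed=0):
    rng = random.Random(seed)
    view1 = [dropout(v, p, rng) for v in ecg_embs]  # dropout mask M1
    view2 = [dropout(v, p, rng) for v in ecg_embs]  # independent dropout mask M2
    # view1[i] and view2[i] come from the same ECG, so they form the positive pair
    return info_nce(view1, view2, temperature)


def total_loss(ecg_embs, report_embs, p=0.1, temperature=0.07, seed=0):
    """Total Loss = CMA Loss + UMA Loss."""
    cma = info_nce(ecg_embs, report_embs, temperature)
    return cma + uma_loss(ecg_embs, p, temperature, seed)


ecg = [[0.5, 1.0, -0.2, 0.3], [1.2, -0.4, 0.8, 0.1], [-0.3, 0.2, 0.9, -1.0]]
reports = [[0.4, 0.9, -0.1, 0.2], [1.0, -0.5, 0.7, 0.0], [-0.2, 0.1, 1.0, -0.9]]
print(total_loss(ecg, reports))
```

Because both views come from the same latent vector, no input-level transformation ever touches the raw ECG waveform, which is the point of augmenting in the latent space.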
Test-time Clinical Knowledge Enhanced Prompt Engineering (CKEPE)
Clinical Knowledge Enhanced Prompt Engineering (CKEPE) is a test-time strategy in the MERL framework that improves zero-shot ECG classification. At test time, CKEPE uses large language models (LLMs) to dynamically generate descriptive, clinically relevant text prompts. To do so, it queries external, expert-verified clinical knowledge databases, such as SNOMED CT (Systematized Nomenclature of Medicine – Clinical Terms) and the SCP-ECG statements, to extract accurate and detailed attributes of cardiac conditions.
Steps in CKEPE:
- Knowledge Extraction: The LLM queries the external databases to gather detailed clinical attributes and subtypes related to the cardiac condition being classified. This step ensures that the prompts are based on verified clinical information, reducing the risk of hallucination and inaccuracies commonly associated with LLM-generated content.
- Prompt Generation: The extracted clinical knowledge is restructured into text prompts that describe the cardiac condition in a detailed and clinically meaningful way. These prompts are designed to be more informative than simple category names or fixed templates used in traditional methods.
- Zero-shot Classification: During the classification process, the ECG embeddings are compared with the embeddings of the generated prompts. The similarity scores between these embeddings are used to classify the ECG signals, leveraging the rich clinical context provided by the enhanced prompts.
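The three CKEPE steps can be sketched as follows. This is a minimal illustration, not the authors' implementation: the KNOWLEDGE dictionary is a hypothetical stand-in for attributes an LLM might extract from SNOMED CT or SCP-ECG statements, build_prompt is a hypothetical helper, and the prompt embeddings are toy vectors standing in for the output of MERL's trained text encoder. Classification picks the prompt with the highest cosine similarity to the ECG embedding.

```python
import math

# Hypothetical knowledge snippets standing in for attributes an LLM might extract
# from SNOMED CT / SCP-ECG statements (illustrative only, not from the paper).
KNOWLEDGE = {
    "atrial fibrillation": ["irregularly irregular rhythm", "absent P waves"],
    "sinus bradycardia": ["regular sinus rhythm", "heart rate below 60 bpm"],
}


def build_prompt(condition, attributes):
    """Prompt generation: restructure a class name and its attributes into one prompt."""
    return f"{condition}: " + ", ".join(attributes)


def cosine(u, v):
    """Cosine similarity between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v) if norm_u and norm_v else 0.0


def zero_shot_classify(ecg_emb, prompt_embs):
    """Zero-shot classification: score every class prompt against one ECG embedding."""
    scores = {label: cosine(ecg_emb, emb) for label, emb in prompt_embs.items()}
    return max(scores, key=scores.get), scores


prompts = {c: build_prompt(c, attrs) for c, attrs in KNOWLEDGE.items()}
print(prompts["sinus bradycardia"])
# sinus bradycardia: regular sinus rhythm, heart rate below 60 bpm

# Toy vectors standing in for text-encoder embeddings of the prompts above.
toy_prompt_embs = {
    "atrial fibrillation": [0.9, 0.1, 0.0],
    "sinus bradycardia": [0.1, 0.9, 0.2],
}
label, scores = zero_shot_classify([0.2, 0.8, 0.1], toy_prompt_embs)
print(label)  # sinus bradycardia
```

Taking the argmax fits a single-label setting; for multi-label datasets one would instead keep the whole scores dictionary as per-class scores.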
Results
Across six datasets, the authors compared zero-shot MERL with eSSL approaches evaluated by linear probing, as summarized in a figure in the paper.
Even without prompt enhancement, zero-shot MERL achieves a higher average AUC across the six datasets than the best eSSL method linearly probed with 1% of the training data. With CKEPE, zero-shot MERL also exceeds the best eSSL method linearly probed with 10% of the training data, underscoring the effectiveness of both CKEPE and the ECG representations learned by MERL.
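Since these comparisons are reported in AUC, here is a minimal pure-Python sketch of a macro-averaged ROC AUC (per-class AUC averaged over classes). It assumes multi-label targets given as 0/1 rows and per-class scores such as the similarity scores from zero-shot classification; the paper's exact averaging protocol may differ, and in practice a library routine such as scikit-learn's roc_auc_score would normally be used.

```python
def roc_auc(labels, scores):
    """Binary ROC AUC: probability that a random positive outscores a random
    negative, with ties counted as 0.5 (the Mann-Whitney formulation)."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            wins += 1.0 if p > n else 0.5 if p == n else 0.0
    return wins / (len(pos) * len(neg))


def macro_auc(label_matrix, score_matrix):
    """Average the per-class AUCs; rows are samples, columns are classes."""
    n_classes = len(label_matrix[0])
    aucs = [roc_auc([row[c] for row in label_matrix], [row[c] for row in score_matrix])
            for c in range(n_classes)]
    return sum(aucs) / n_classes


labels = [[1, 0], [0, 1], [1, 0], [0, 1]]
scores = [[0.8, 0.3], [0.6, 0.7], [0.4, 0.1], [0.2, 0.9]]
print(macro_auc(labels, scores))  # 0.875
```

The pairwise loop is quadratic in the number of samples, which is fine for illustration but not for large test sets; a sort-based rank computation scales better.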
The authors also report the performance of zero-shot MERL and the eSSL methods on each individual dataset, as shown in a figure in the paper.
Remarkably, zero-shot MERL outperforms the linearly probed eSSL methods on every downstream dataset, even without additional training samples. This underlines MERL's ability to learn robust, transferable cross-modal ECG features that carry clinically relevant knowledge from report supervision.
Distribution shift refers to scenarios where the test ECGs come from a different distribution than the training ECGs, often because they come from a different data source. The authors focused on the most common type of distribution shift in healthcare data: domain shift (covariate shift), where the label space is shared but the input distributions differ. To evaluate how well the learned ECG representations generalize across sources, they compared eSSL methods with linear probing against zero-shot MERL under domain shift: training on one dataset (the 'source domain') and testing on another (the 'target domain') that shares categories with the source domain. In these experiments, MERL with CKEPE was clearly more robust than the eSSL methods.
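The domain-shift setup depends on restricting evaluation to the categories that the source and target domains share. A minimal sketch of that bookkeeping follows. The class names are illustrative (NORM, MI, STTC and CD are PTB-XL diagnostic superclasses, and AF is included only as an example of a target-only label), and dropping target samples that have no shared label is an assumption, not a detail confirmed by the paper.

```python
def shared_categories(source_classes, target_classes):
    """Categories present in both label spaces (sorted for reproducibility)."""
    return sorted(set(source_classes) & set(target_classes))


def restrict_to_shared(samples, shared):
    """samples: list of (sample_id, set_of_labels).
    Drops labels outside the shared space, then drops samples left with no
    shared label (an assumption made for this sketch)."""
    shared_set = set(shared)
    kept = []
    for sample_id, labels in samples:
        common = labels & shared_set
        if common:
            kept.append((sample_id, common))
    return kept


source = ["NORM", "MI", "STTC", "CD"]
target = ["NORM", "MI", "AF"]
shared = shared_categories(source, target)
print(shared)  # ['MI', 'NORM']

target_samples = [("a", {"MI", "AF"}), ("b", {"AF"}), ("c", {"NORM"})]
print(restrict_to_shared(target_samples, shared))  # [('a', {'MI'}), ('c', {'NORM'})]
```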
To understand which parts of the framework matter, the authors also ran several ablation experiments.
Loss Function
The most interesting finding for me was that adding the uni-modal alignment (UMA) loss term improves the model's performance. This suggests that UMA enhances the model's ability to learn ECG representations in the latent space, which benefits downstream tasks.
Text Encoder
The authors evaluate several text encoders on the zero-shot and linear probing tasks. Med-CPT achieves the highest scores: 75.24% for zero-shot and 65.96% for linear probing. This is attributed to its discriminative and representative text embeddings, since Med-CPT is pre-trained on a text contrastive learning task. The other encoders, which are pre-trained with masked language modeling, score lower.
Clinical Knowledge Database
Both the web and the local clinical knowledge databases contribute substantially to zero-shot performance. Removing the web database causes the largest drop, to 72.17%, underscoring its importance given its larger scale. Removing the local SCP Statement database reduces performance to 73.62%. Using both databases together yields the highest performance, 75.24%, which shows that they are complementary.
Data Augmentation Strategies
Four data augmentation strategies are evaluated on the zero-shot and linear probing tasks: three signal-level augmentations (Cutout, Drop, and Gaussian noise) and latent space augmentation. The signal-level augmentations can distort semantic information, which leads to lower-quality representations. Latent space augmentation avoids this distortion and performs best, with a zero-shot score of 75.24% and a linear probing score of 65.96%, highlighting its advantage over traditional signal-level augmentations.
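To make the contrast with latent space augmentation concrete, here is a minimal sketch of the three signal-level augmentations applied to a single-lead toy trace. The parameters are illustrative rather than the paper's settings, and treating "Drop" as independently zeroing individual samples is my assumption. The cutout example shows how one window can erase a whole spike, which is the kind of semantic distortion the ablation points to.

```python
import random


def cutout(signal, start, length):
    """Zero a contiguous window; if it covers a QRS complex, that beat is lost."""
    out = list(signal)
    for i in range(start, min(start + length, len(out))):
        out[i] = 0.0
    return out


def random_drop(signal, p, rng):
    """Zero individual samples independently with probability p (assumed reading of 'Drop')."""
    return [0.0 if rng.random() < p else x for x in signal]


def gaussian_noise(signal, sigma, rng):
    """Add i.i.d. Gaussian noise with standard deviation sigma to every sample."""
    return [x + rng.gauss(0.0, sigma) for x in signal]


rng = random.Random(0)
ecg = [0.0, 0.1, 1.2, -0.3, 0.0, 0.1, 1.1, -0.2]  # toy trace with two "beats"
print(cutout(ecg, 2, 2))  # [0.0, 0.1, 0.0, 0.0, 0.0, 0.1, 1.1, -0.2]
print(random_drop(ecg, 0.25, rng))
print(gaussian_noise(ecg, 0.05, rng))
```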
Dropout Ratio
The authors also compare different dropout ratios for latent space augmentation. A dropout ratio of 0.1 gives the best results, with a zero-shot performance of 75.24% and a linear probing performance of 65.96%. Both higher and lower ratios reduce performance, so 0.1 is the best setting among those tested for maintaining representation quality and robustness during training.
Feature Extractors for ECG
The performance comparison between CNN-based ResNet18 and transformer-based ViT-Tiny as ECG feature extractors reveals that the CNN-based ResNet18 outperforms ViT-Tiny. The zero-shot and linear probing scores for ResNet18 are 75.24% and 65.96%, respectively, compared to ViT-Tiny's scores of 73.54% and 63.53%. This suggests that CNNs are better suited for capturing ECG patterns, while the tokenization strategy in transformers may lead to information loss, affecting performance.
You can find the complete code for this project here.
In this article, I explained the paper "Zero-Shot ECG Classification with Multimodal Learning and Test-time Clinical Knowledge Enhancement".
It introduced MERL, a scalable and effective multimodal ECG learning framework that incorporated CMA and UMA alignment strategies during training, and CKEPE, a strategy for customizing prompts, during testing. CKEPE leveraged the capabilities of LLMs to extract and restructure clinical knowledge from a provided database, boosting zero-shot MERL to outperform eSSL with linear probing in classification tasks.
If you work in AI and healthcare, you will likely love this paper; that is why I wanted to write an article about it. I hope it helped you 🙂
Thank you for reading the article!
Connect with me on LinkedIn 😊
Follow me on Medium!
References:
[1] Zero-Shot ECG Classification with Multimodal Learning and Test-time Clinical Knowledge Enhancement. ICML 2024.