PM²F²N: Patient Multi-view Multi-modal Feature Fusion Networks for Clinical Outcome Prediction



Introduction
With the development of information technology in the medical field, an increasing number of devices are used for monitoring patients, and a large amount of data is stored as electronic health records (EHR), which contain numerical physical examination results in time series and clinical notes in text describing patients' relevant information. This multi-type data can be utilized to predict the condition of patients, which can help hospitals manage their resources. Most previous works modeled the problem using the time series data recorded by medical instruments (Ghassemi et al., 2015; Xu et al., 2018). However, the time series data gathered from medical devices reflects the physical status of patients in only a one-sided way. Medical professionals need to apply their expertise to analyze patients' data and make diagnoses, and these important analyses are recorded in EHR as clinical notes.
More recent work applied natural language processing (NLP) methods to take full advantage of the medical information in clinical notes for prediction tasks (Boag et al., 2018; Lee et al., 2020; van Aken et al., 2021). These methods utilized pre-trained language models to extract text features of clinical notes and fed them into recurrent or convolutional neural networks for clinical outcome prediction. Furthermore, to combine time series data with clinical notes for improved clinical outcome prediction, some recent work proposed multi-modal learning methods that jointly model the two kinds of data (Khadanga et al., 2019; Bardak and Tan, 2021a; Deznabi et al., 2021). They used sequence models to extract features of time series and clinical notes respectively and concatenated them for predicting clinical outcomes. However, the existing methods do not consider that the features of time series data and clinical notes should attend to different parts of each other with various weights. Besides, the multi-modal features of a single patient are not sufficient for clinical outcome prediction, and the medical correlation between patients has not been exploited for this task.
To overcome the above disadvantages of the existing methods, we propose the patient multi-view multi-modal feature fusion networks (PM²F²N) for clinical outcome prediction. The model enhances the multi-modal feature fusion ability in two views. Firstly, from the patient inner view, we use the co-attention module (Lu et al., 2016) to enhance the fine-grained feature interaction between time series data and clinical notes. The co-attention module allows our model to attend to important parts of the time series data as well as correlated medical information in the clinical notes. Secondly, from the patient outer view, other patients' information is useful for predicting the status of the observed one. We construct a patient correlation graph based on the structural medical information extracted from clinical notes, and fuse patients' multi-modal features by graph neural networks (GNN). With the multi-modal feature fusion from different views, our model gains better generalization ability for predicting clinical outcomes. The contributions of this manuscript can be summarized as follows:

1. We analyze the disadvantages of the existing methods for clinical outcome prediction. To improve the ability to fuse multi-modal features from different views, we propose the patient multi-view multi-modal feature fusion networks.
2. From the patient inner view, we extend the co-attention module to enhance fine-grained feature fusion between time series data and clinical notes. Besides, from the outer view, we exploit the patient correlation graph to aggregate multi-modal features between patients.
3. We evaluate the effectiveness of the proposed model on the MIMIC-III benchmark. The experimental results demonstrate that our model outperforms the baseline approaches, and further analysis of the multi-modal features also shows the superiority of our model.

Time Series for Clinical Outcome Prediction
Earlier works on mortality prediction designed hand-crafted features and used traditional machine learning methods such as logistic regression with various severity scores (Vincent et al., 1996). With the progress of deep learning, sequence models such as long short-term memory networks (LSTM) (Hochreiter and Schmidhuber, 1997) and gated recurrent units (GRU) (Cho et al., 2014) have been utilized to handle time series data for clinical outcome prediction. Besides, some researchers exploited the irregular sampling of the data over time in their prediction models (Zhang et al., 2021b). Furthermore, the self-attention mechanism has also been used to capture dependencies within the time series data for clinical outcome prediction (Song et al., 2018; Ma et al., 2020).

Clinical Notes for Clinical Outcome Prediction
Considering that time series data is limited in explicit medical information, some works focused on using clinical notes for outcome prediction. They utilized pre-trained word embeddings (Zhang et al., 2019) as text features of clinical notes and fed them into recurrent neural networks (RNN) or convolutional neural networks (CNN) to extract hidden features for predicting results (Ghorbani et al., 2020). Besides, external medical knowledge is useful for predicting the physical status of patients. A clinical outcome pre-training method was proposed to integrate knowledge from multiple patient outcomes (van Aken et al., 2021).

Multi-modal Learning for Clinical Outcome Prediction
The above methods are limited in fusing the various sources of available data when predicting medical outcomes, whereas in multi-modal learning the data of each modality can be enhanced by the others. Multi-modal learning over time series data and clinical notes has shown its effectiveness for clinical outcome prediction (Khadanga et al., 2019; Deznabi et al., 2021). These methods utilized an RNN to extract hidden representations of time series data and a CNN to acquire those of clinical notes; the two hidden features were then concatenated and fed into feed-forward neural networks (FFNN) for predicting results. Besides, to model robust representations of patients' multi-type data in EHR, a supervised deep patient representation learning framework was proposed for clinical outcome prediction (Zhang et al., 2021a). To make use of the sparse medical information in clinical notes, a named entity recognition (NER) model was utilized to extract entities in texts, and their representations were introduced into a multi-modal learning model for making predictions (Bardak and Tan, 2021a). The existing methods do not consider evaluating the status of patients from different aspects. Therefore, we propose the PM²F²N model to fuse multi-modal features from various views for clinical outcome prediction with better generalization ability.

Model
We introduce the notation for clinical outcome prediction before getting into the details of the proposed model. The training set with N_s samples is denoted as {(X^(i), C^(i), y^(i))}_{i=1}^{N_s}, where X^(i) and C^(i) are the i-th patient's time series data and clinical note respectively, and y^(i) is the task-defined label. Given a time series with N_t time steps and N_v observed variables, the patient's vital signals can be formulated as X = {x_1, x_2, ..., x_{N_t}} ∈ R^{N_t×N_v}. We denote the clinical note with N_w words as C = {c_1, c_2, ..., c_{N_w}}. After pre-processing the multi-modal data, we feed them into the proposed model.
The PM²F²N model is shown in Figure 2. For time series data, we utilize a bidirectional GRU to extract hidden representations. To acquire multi-grained features of clinical notes, we apply an NER model to extract medical entities as local features and use the term frequency-inverse document frequency (TF-IDF) method to extract global features. To combine the entity representations of clinical notes with the hidden representations of time series data in a fine-grained way, we exploit the co-attention module to acquire multi-modal fusion features with various attention weights. Based on the medical information of different patients, we build the patient correlation graph and exploit it to aggregate the multi-modal features of various neighbors via GNN. The concatenation of the global features of clinical notes, the last hidden features of time series data, and the aggregated multi-modal features is fed into an FFNN for outcome prediction.

Multi-modal Feature Extraction
Given the multi-modal data as input, we need to pre-process them and map them into dense representations for the deep neural networks, as shown in Figure 2. We denote the time series data, which has N_t time steps and N_v observed variables, as X = {x_1, x_2, ..., x_{N_t}} ∈ R^{N_t×N_v}. Given the impressive performance of RNNs, GRU and LSTM are commonly utilized to extract hidden representations of sequence data. To capture the context information in both forward and backward directions, we utilize a bidirectional GRU (BiGRU) to acquire the hidden features of the time series data X (Bardak and Tan, 2021a). The extraction process is simplified as H = BiGRU(X; θ_1) ∈ R^{N_t×N_h}, where N_h is the dimension of the hidden feature vectors and θ_1 denotes the trainable parameters of the BiGRU.

The clinical notes contain detailed information about patients and medical knowledge implicit in the inference of doctors. Considering that clinical notes may contain redundant information, we need to extract representative features that highlight critical patient information. Therefore, we propose to extract multi-grained features of clinical notes. To make full use of the unstructured clinical note C, we utilize TF-IDF to extract a global feature vector in which the important tokens of the clinical note are represented. However, the dimension of the raw global feature vector is too high to represent the patient compactly and fit into memory. We therefore apply principal component analysis (PCA) to reduce its dimension, and the dimension-reduced global feature of the clinical note is defined as C_g ∈ R^{N_g}, where N_g is the dimension of the global feature vector.
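As a concrete sketch of this global-feature pipeline, the following example runs TF-IDF over a handful of toy notes and reduces the result with PCA. The note texts and the reduced dimension are illustrative only, not taken from the paper:

```python
# Sketch of the global-feature extraction: TF-IDF followed by PCA.
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import PCA

# Toy clinical notes (illustrative; real notes come from MIMIC-III).
notes = [
    "patient admitted with sepsis, started on vancomycin",
    "chest pain, troponin elevated, aspirin given",
    "sepsis resolving, vancomycin continued",
    "no acute distress, discharged home",
]

tfidf = TfidfVectorizer()
X_tfidf = tfidf.fit_transform(notes).toarray()  # (num_notes, vocab_size)

n_g = 2                                         # reduced dimension N_g (illustrative)
pca = PCA(n_components=n_g)
C_g = pca.fit_transform(X_tfidf)                # (num_notes, N_g)
```

In practice, TruncatedSVD is often preferred over PCA for sparse TF-IDF matrices since it avoids densifying the full vocabulary dimension.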
Besides, clinical notes contain various types of medical information defined as entities, including diseases, drugs, dosages, and so on (Kormilitzin et al., 2021). This structural medical knowledge, i.e., the entities, is the most important information for representing the status of patients. Since the raw clinical notes contain much redundant free text, we apply the Med7 NER model to extract the medical entity set C_e from each note and encode the entities with a pre-trained language model to obtain the local entity features C_b ∈ R^{N_e×N_b}, where N_e is the number of entities and N_b is the dimension of the entity embeddings.

Multi-modal Feature Fusion with Multi-view
There are different views from which to evaluate the physical status of patients. From the inner view of the observed patient, the doctor analyzes the multi-modal data to make a diagnosis. Based on accumulated clinical experience, the doctor can also exploit the correlation between patients when providing a diagnostic result for the observed patient. From the outer view, therefore, the multi-modal data of other patients who are correlated with the observed one in terms of medical knowledge is also beneficial to the diagnostic result. To enhance the representation ability of our model, we propose to improve the multi-modal feature fusion strategy in these two views.

Feature Fusion with Inner View
Given the unified features H̃ and C̃_b, the shared feature space is defined as S = tanh(H̃ W_s C̃_b^T), where W_s ∈ R^{N_d×N_d} is the trainable weight in the module. The shared feature S is used to calculate the correlation between the time series and medical entity features. The attention weights over the two modalities are obtained by normalizing S along each of its dimensions with softmax, and the multi-modal fusion features of the time series data and the clinical note are then calculated as the correspondingly weighted sums of the unified features, denoted as Ĥ and Ĉ_b.
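A minimal NumPy sketch of this co-attention step follows. The shared space S matches the text, while the softmax directions and weighted sums are our assumptions based on the parallel co-attention of Lu et al. (2016), since those equations are partly omitted in the extracted text. All sizes are toy values:

```python
import numpy as np

def softmax(x, axis):
    # Numerically stable softmax along the given axis.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
N_t, N_e, N_d = 5, 3, 4                   # time steps, entities, unified dim (toy)
H_tilde = rng.normal(size=(N_t, N_d))     # projected time-series features
C_tilde = rng.normal(size=(N_e, N_d))     # projected entity features
W_s = rng.normal(size=(N_d, N_d))         # trainable weight in the module

# Shared feature space S = tanh(H~ W_s C~^T), as in the text.
S = np.tanh(H_tilde @ W_s @ C_tilde.T)    # (N_t, N_e)

# Assumed fusion: softmax over each dimension gives attention weights,
# which produce weighted sums of the unified features.
A_h = softmax(S, axis=0)                  # attend over time steps
A_c = softmax(S, axis=1)                  # attend over entities
H_hat = A_h.T @ H_tilde                   # (N_e, N_d) entity-conditioned view
C_hat = A_c @ C_tilde                     # (N_t, N_d) time-conditioned view
```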

Feature Fusion with Outer View
To analyze the physical status of the observed patient, the relevant patients' information is worth referring to. Patients with approximate physiological conditions are represented with similar multi-modal features. Therefore, we construct a correlation graph between patients and aggregate their multi-modal features from neighbors with medical-knowledge relevance. Given the clinical notes {C^(i)}_{i=1}^{N_s} in the training set, we have acquired their medical entity sets {C_e^(i)}_{i=1}^{N_s} with the Med7 model. The patient correlation graph (PCG) A ∈ R^{N_s×N_s} is initialized as an identity matrix, and the elements {A_ij | i, j ∈ {1, 2, ..., N_s}} of the PCG are the correlation degrees between the i-th patient and the j-th one. Considering that the medical entities are the most important information for representing a patient, we exploit them to evaluate the correlation degree between each pair of patients, as shown in Figure 3. Using the Jaccard similarity as the metric for the correlation of two sets, the correlation degrees are calculated as A_ij = |C_e^(i) ∩ C_e^(j)| / |C_e^(i) ∪ C_e^(j)|. We concatenate the original extracted multi-modal features and the fusion ones as the patients' features P = {p^(i)}_{i=1}^{N_s}, where the i-th patient's multi-modal feature is calculated as p^(i) = Concat(H^(i), Ĥ^(i), C_b^(i), Ĉ_b^(i)) W_p + b_p, and W_p and b_p are trainable weights in the model.
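The PCG construction can be sketched as follows; the toy entity sets merely stand in for Med7 output:

```python
import numpy as np

def build_pcg(entity_sets):
    """Patient correlation graph: A[i, j] is the Jaccard similarity between
    the medical-entity sets of patients i and j (diagonal initialized to 1)."""
    n = len(entity_sets)
    A = np.eye(n)
    for i in range(n):
        for j in range(i + 1, n):
            a, b = set(entity_sets[i]), set(entity_sets[j])
            A[i, j] = A[j, i] = len(a & b) / len(a | b) if (a | b) else 0.0
    return A

# Toy entity sets standing in for Med7 output (illustrative):
sets_ = [{"sepsis", "vancomycin"}, {"sepsis", "aspirin"}, {"aspirin", "troponin"}]
A = build_pcg(sets_)
# A[0, 1] is 1/3: one shared entity out of three distinct entities.
```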
To update the observed patient's multi-modal features via the relevant patients, we utilize graph convolutional networks (GCN) (Kipf and Welling, 2017) to aggregate those of the neighbors. The calculation of the aggregated multi-modal features is simplified as P̃ = σ(A P W_g + b_g), where W_g and b_g are trainable weights in the GCN module and σ is a nonlinear function. The patient multi-modal fusion features with the outer view are denoted as P̃ = {p̃^(i)}_{i=1}^{N_s}, each of which incorporates the features of various correlated patients.
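A minimal sketch of this aggregation step, with σ taken as ReLU (the paper leaves the nonlinearity generic) and toy sizes:

```python
import numpy as np

def gcn_aggregate(A, P, W_g, b_g):
    """One GCN step P~ = sigma(A P W_g + b_g): each patient's features become
    a nonlinear transform of its correlation-weighted neighborhood."""
    return np.maximum(A @ P @ W_g + b_g, 0.0)  # ReLU as the nonlinearity

rng = np.random.default_rng(0)
N_s, d_in, d_out = 6, 8, 4                      # toy sizes
A = rng.random((N_s, N_s))                      # stands in for the PCG
P = rng.normal(size=(N_s, d_in))                # patients' concatenated features
W_g, b_g = rng.normal(size=(d_in, d_out)), np.zeros(d_out)

P_agg = gcn_aggregate(A, P, W_g, b_g)           # (N_s, d_out)
```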

Training Procedure
After acquiring the multi-view multi-modal fusion features, we utilize them to predict the target probabilities. The concatenation of the multi-view multi-modal fusion features and the original extracted features is fed into the FFNN. The prediction probability is calculated as ŷ^(i) = FFNN([h^(i)_{N_t}; C_g^(i); p̃^(i)]; θ_2), where θ_2 denotes the trainable weights of the FFNN module. To solve the classification task, we utilize the cross-entropy loss:

L = -(1/N_s) Σ_{i=1}^{N_s} [y^(i) log ŷ^(i) + (1 - y^(i)) log(1 - ŷ^(i))]. (2)

We feed the multi-modal data into the model and acquire the loss according to Equation 2. To train the parameter weights of the model, we use the stochastic gradient descent (SGD) method to update them according to the calculated loss.
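The cross-entropy loss above can be sketched directly (names illustrative):

```python
import numpy as np

def bce_loss(y_true, y_prob, eps=1e-12):
    """Binary cross-entropy averaged over the batch, the form of Equation 2."""
    y_prob = np.clip(y_prob, eps, 1.0 - eps)  # guard against log(0)
    return float(-np.mean(y_true * np.log(y_prob)
                          + (1.0 - y_true) * np.log(1.0 - y_prob)))

# Two toy samples: a positive predicted at 0.9 and a negative predicted at 0.1.
loss = bce_loss(np.array([1.0, 0.0]), np.array([0.9, 0.1]))
```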

Dataset and Experiment Settings
We compare the proposed model with the existing methods on the medical benchmark dataset MIMIC-III (Johnson et al., 2016). The dataset contains multi-type data collected from real scenarios, including vital signals, clinical notes, ICD-9 codes, and so on. We follow the previous work (Bardak and Tan, 2021a) to extract the time series data and clinical notes from the raw dataset with the publicly available tool MIMIC-Extract (Wang et al., 2020). The detailed statistics of the dataset are shown in Table 1. The dataset is commonly used for two targets: mortality and length of stay (LOS), and the above works analyze four binary classification tasks as follows:

1. In-hospital Mortality: This task aims to predict whether a patient dies before being discharged.
2. In-ICU Mortality: This task is defined to detect patients who are physically declining and predict their mortality within 24 hours.
3. LOS >3: This task aims to predict whether a patient stays in the ICU longer than 3 days.
4. LOS >7: This task is defined to detect patients who stay in the ICU longer than 7 days.
After extracting the dataset, we utilize the Python package fancyimpute to impute the missing values in the time series data. We feed the clinical notes into the Med7 model to extract the medical entities and utilize the BioBERT-Large (Lee et al., 2020) version of the language model BERT to extract text features of the entities. The dimensions N_d and N_k of the hidden features in the co-attention module are set to 128, and the others are set to 256. We set the dropout rate and learning rate to 0.5 and 0.001 respectively. During training, we train the model on the training set for at most 300 epochs and evaluate it on the development set. Following the early stopping strategy, we stop training when the loss on the development set does not decrease within 20 epochs. We use two metrics, AUROC and AUPRC, to evaluate the models on these imbalanced tasks. All experiments are accelerated by a single NVIDIA RTX A6000 device.
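The patience-based early stopping described above can be sketched as follows (a minimal class, not the paper's actual training script):

```python
class EarlyStopping:
    """Stop training once the development loss has not improved
    for `patience` consecutive epochs (20 in the experiments above)."""

    def __init__(self, patience=20):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, dev_loss):
        """Record one epoch's dev loss; return True when training should stop."""
        if dev_loss < self.best:
            self.best, self.bad_epochs = dev_loss, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Toy run with patience 2: improvement, improvement, then two flat epochs.
stopper = EarlyStopping(patience=2)
decisions = [stopper.step(l) for l in [1.0, 0.9, 0.95, 0.97]]
```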

Compared Methods
We compare the proposed model with existing machine learning methods. The models proposed by Khadanga et al. (2019) and Deznabi et al. (2021) combine time series data with clinical notes using a simple feature fusion strategy for outcome prediction. Besides, a calibrated random forest (CaliForest) utilizing out-of-bag samples was proposed for risk prediction (Park and Ho, 2020). Taking structural medical information into account, the models proposed by Bardak and Tan (2021b,a) combine time series data with important medical mentions for clinical outcome prediction. Since robust representations of patients' multi-modal data in EHR are critical to downstream tasks, a supervised deep patient representation learning framework was proposed for outcome prediction (Zhang et al., 2021a). A label-aware attention mechanism was introduced into a multi-modal learning method (Yang et al., 2021) for the prediction task.

Experimental Results
We compare PM²F²N with the baseline methods on the four classification tasks. The detailed experimental results on MIMIC-III are shown in Table 2. Our model achieves the best results on all four tasks compared with the baseline methods, with AUROC and AUPRC scores increasing by 1.1% ∼ 3.7% and 0.4% ∼ 10.5% over the baselines respectively. Compared with the traditional method CaliForest (Park and Ho, 2020), the deep learning models achieve better results on most classification tasks. Our model outperforms the multi-modal learning methods with simple feature fusion strategies (Khadanga et al., 2019; Deznabi et al., 2021) because it takes full advantage of patient multi-view multi-modal feature fusion. Although the models by Bardak and Tan (2021b,a) utilized the medical entities in clinical notes, they did not model fine-grained features between the multi-modal data, and the model by Yang et al. (2021) incorporated label information to enhance the text features of clinical notes but ignored fine-grained feature fusion. Our model gains better results over them through the co-attention module, which effectively models the multi-modal fusion features. Besides, the representation learning method (Zhang et al., 2021a) is beneficial to the downstream risk prediction task; however, it did not take the patient correlation in medical knowledge into account or model the relevant multi-modal features. Our model exploits the structural medical information to construct the patient correlation graph and fuses the multi-modal features with a GCN based on the graph. Therefore, the proposed model gains better generalization ability for clinical outcome prediction.

Further Discussion
To dig into the model, we conduct detailed analyses to present it from different aspects. An ablation study is performed to demonstrate the effectiveness of the different feature fusion strategies proposed in our model. Besides, to verify the effectiveness of the patient correlation graph, we compare the performance of models whose adjacency matrices are filled with different values. Finally, we visualize the multi-modal features extracted from the proposed model to present the usefulness of the patient correlation information from the feature fusion aspect.

Ablation Study
As shown in Table 3, we conduct an ablation study to present the effectiveness of the proposed multi-modal feature fusion strategies. We utilize the single-modal data (TS) and the multi-modal data (TS + CN) to train clinical outcome prediction models respectively as the base comparison methods. That the model trained with multi-modal data achieves better results than that with single-modal data proves the effectiveness of multi-modal learning. When the patient correlation graph (PCG) is introduced into the base multi-modal learning method, the results on the four tasks are improved to varying degrees: the multi-modal feature fusion with the outer view aggregates the features of various patients and improves the generalization ability of the proposed model. The model further incorporating co-attention (CA) is the proposed model PM²F²N, which gains additional improvements on the four tasks; with the advantage of CA, our model fuses the multi-modal features in a fine-grained way.

Effect of Patient Correlation Graph
As shown in Figure 4, we conduct comparison experiments to demonstrate the effectiveness of the proposed patient correlation graph (PCG). The proposed model is fed with two distinct adjacency matrices, filled with all 0s and all 1s respectively, to replace the PCG. The "Adj=0" model utilizes the adjacency matrix filled with all 0s to disentangle the GCN from our model as a baseline, and the "Adj=1" model exploits the adjacency matrix filled with all 1s to verify the effect of the patient correlation degrees. Compared with the baseline "Adj=0" model, the "Adj=1" model suffers various drops in AUROC and AUPRC, while our model gains improvements of 0.5% ∼ 1.3%.

Visualization Analysis
To intuitively verify the effectiveness of the patient correlation graph (PCG) for the multi-modal fusion features, we visualize the learned features extracted from the models with and without the patient correlation graph, as shown in Figure 5. We focus on the LOS >3 task because of its balanced class distribution. After training the models, we utilize them to acquire the multi-modal features of the samples in the test set. We visualize the patient multi-modal fusion features in Figure 5, where the dimension is reduced to two by t-SNE. Furthermore, we select the same group of patients in both plots to highlight and circle their feature points. Considering all patients, the multi-modal features with the same label that are learned with the PCG are more clustered. The selected patients' features learned with the PCG are clustered into two groups with a clear boundary, while those without the PCG are scattered and intertwined.
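The dimension reduction used for this visualization can be reproduced with scikit-learn's t-SNE; the feature matrix below is random, standing in for the learned multi-modal features:

```python
import numpy as np
from sklearn.manifold import TSNE

rng = np.random.default_rng(0)
features = rng.normal(size=(50, 16))   # stands in for test-set fusion features

# Reduce to two dimensions for plotting; note that perplexity must stay
# below the number of samples.
embedded = TSNE(n_components=2, perplexity=10,
                random_state=0).fit_transform(features)
```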
The comparison between the two results demonstrates the effectiveness of the patient correlation graph, which connects the multi-modal features of relevant patients.

Conclusion
In this paper, we analyze the disadvantages of existing multi-modal learning methods for clinical outcome prediction. To enhance the multi-modal features from different views, we propose the patient multi-view multi-modal feature fusion networks (PM²F²N) for the task. From the inner view, we extend the co-attention module to fuse the features of the time series data and the structural medical knowledge in a fine-grained way. From the outer view, we exploit the correlation between patients to aggregate the multi-modal features of similar patients. With the multi-view multi-modal feature fusion strategy, the proposed model learns general patient representations for clinical outcome prediction. Compared with the existing methods, our model gains the best results on the benchmark dataset MIMIC-III, and the further discussion, including the ablation study, the effect of the PCG, and the visualization analysis, verifies the effectiveness of the proposed strategy. Considering the heterogeneity of patients, we will try to adapt heterogeneous graphs for modeling the correlation between patients in the future.

Limitations
The proposed model must be fed with the multi-modal data of all patients at once and requires large memory to calculate the aggregated multi-modal features in the GCN layer. Besides, the scalability of the patient correlation graph (PCG) is poor, because the PCG must be reconstructed whenever new patients are added to the original patient set.

Figure 1 :
Figure 1: There are two views from which to analyze the observed patient. The inner view focuses on the medical data of the observed patient, and the outer view exploits the medical correlation between patients for the observed one.

Figure 2 :
Figure 2: The patient multi-view multi-modal feature fusion networks (PM²F²N) for clinical outcome prediction.
The time series data X contains various physiological signals changing over time, and the medical entity set C_e includes different medical knowledge representing the condition of a patient. Some of the medical information in a clinical note is relevant to the physical signals at certain times; for example, when the observed patient is treated with certain drugs, the physical signals change in some aspect. To capture the fine-grained correlation between the multi-modal data, we propose to exploit the co-attention module for fusing the multi-modal features. Although co-attention has achieved significant success in the visual question answering area (Lu et al., 2016), this is the first time it has been extended to medical multi-modal data mining. Given the extracted features H of the time series data X and C_b of the clinical note C, we unify the dimension of both as H̃ = H W_a and C̃_b = C_b W_b, where W_a ∈ R^{N_h×N_d} and W_b ∈ R^{N_b×N_d} are trainable weights. To calculate the correlation degree between the time series and the medical entities, the shared feature space S ∈ R^{N_t×N_e} is defined as S = tanh(H̃ W_s C̃_b^T).

Figure 3 :
Figure 3: The detailed construction of the patient correlation graph. The correlation degree between two patients is defined as the Jaccard similarity between the medical entity sets in their clinical notes.

Figure 5 :
Figure 5: The t-SNE visualization results of the multi-modal features extracted from the models with and without the patient correlation graph respectively. We evaluate the models on the LOS >3 task, and the highlighted points represent the same group of patients.

Table 1 :
The statistics of the MIMIC-III dataset extracted by MIMIC-Extract. "T.S." is short for "time series".

Table 2 :
The experimental results on the four clinical outcome prediction tasks in macro-averaged % AUROC and % AUPRC. We run the experiments 5 times with different random seeds and report the average results. Our model outperforms the baseline methods on all four tasks.

Table 3 :
The results of the ablation study. "TS" and "CN" are short for time series data and clinical notes respectively. "PCG" represents the patient correlation graph introduced into our model. "CA" is the co-attention module fusing the features of time series and clinical notes.