
CAT-DTI: cross-attention and Transformer network with domain adaptation for drug-target interaction prediction | BMC Bioinformatics


An overview of the CAT-DTI framework is illustrated in Fig. 1a. Given drug SMILES strings and protein amino acid sequences as input, drug and protein embeddings are generated. The drug embedding is fed into a GCN to extract feature representations of drug molecules (i.e., the drug feature map \(F_D\)). The protein embedding is passed to the protein feature encoder shown in Fig. 1b, which combines a CNN and a Transformer to extract the protein feature map \(F_P\), capturing local features and global context information in the protein sequence simultaneously. Next, the cross-attention module fuses protein and drug features to capture the interaction between drugs and targets, as shown in Fig. 1c. Specifically, we swap the key and value of protein attention with those of drug attention. The attended feature maps are then integrated with the original features to construct the final features for both drugs and proteins. Through max-pooling and concatenation, the joint feature \(f\) for the drug and protein target is produced and input into the decoder to predict DTI. To enhance the generalization performance of CAT-DTI in real-world scenarios involving novel drug-target pairs, we integrate the domain adaptation module CDAN into the framework, which adapts the representations of drugs and proteins and thereby aligns the source and target domain distributions.

Fig. 1

Framework of the proposed CAT-DTI. a Overview of CAT-DTI framework. b Details of the protein feature encoder. c Processes of cross-attention

GCN for drug molecular graph

Regarding the drug feature extraction process, we transform drug SMILES into a corresponding 2D molecular graph. To capture the node information within the graph, we first initialize each atom node. Each atom is denoted by a 74-dimensional integer vector that encapsulates eight distinct attributes, including the atom type, the atom degree, the number of implicit Hs, the formal charge, the number of radical electrons, the atom hybridization, the number of total Hs and whether the atom is aromatic.

The drug feature encoder propagates and aggregates information on the drug molecular structure through a three-layer GCN, thereby extracting a representation of the drug features. In each GCN layer, each row of the drug representation is an aggregated representation of the adjacent atomic nodes in the drug molecule. Each GCN layer uses the information of neighboring atomic nodes to update the feature representation of each atomic node, allowing the model to effectively capture the correlation between neighboring atomic nodes. We retain node-level drug representations for subsequent explicit learning of interactions with protein fragments. We set the maximum number of nodes in the graph to \(m_d\). Therefore, the node feature matrix of each graph is denoted as \(M_d\in \mathbb {R}^{m_d\times 74}\). Furthermore, we employ a simple linear transformation \(F_{d}=M_{d}W_{o}^{\top }\), resulting in a real-valued dense matrix \(F_d\in \mathbb {R}^{m_d\times D_d}\) as input features, where \(D_d\) is the drug embedding dimension. Finally, we obtain the drug feature map \(F_D\in \mathbb {R}^{m_d\times D_d}\) through the drug feature encoder, which can be expressed as:

$$\begin{aligned} H_d^{i+1}=\sigma (\text {GCN}(\widetilde{A},W_{gcn}^i,b_{gcn}^i,H_d^i)), \end{aligned}$$

(1)

where \(W_{gcn}^i\) and \(b_{gcn}^i\) are the weight matrix and bias vector of the i-th GCN layer, \(\tilde{A}\) is the adjacency matrix with added self-connections, and \(H_{d}^i\) denotes the hidden node representation of layer i, with \(H_{d}^{0}=F_{d}\).
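As an illustration of Eq. (1), the following is a minimal NumPy sketch of a three-layer GCN drug encoder. The symmetric degree normalization of \(\tilde{A}\) and the ReLU activation are assumptions borrowed from the standard GCN formulation; the text above does not specify them. The function names `gcn_layer` and `drug_encoder`, the toy chain-shaped molecule, and the random weights are hypothetical.

```python
import numpy as np

def gcn_layer(A_tilde, H, W, b):
    """One GCN update, assumed here to be ReLU(D^-1/2 A~ D^-1/2 H W + b)."""
    deg = A_tilde.sum(axis=1)                 # >= 1 because of the self-loops
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    A_hat = A_tilde * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]
    return np.maximum(A_hat @ H @ W + b, 0.0)

def drug_encoder(A, M_d, W_o, layers):
    """Map node features M_d (m_d x 74) to a drug feature map F_D (m_d x D_d)."""
    A_tilde = A + np.eye(A.shape[0])          # adjacency with self-connections
    H = M_d @ W_o.T                           # F_d = M_d W_o^T, i.e. H_d^0
    for W, b in layers:                       # H_d^{i+1} = sigma(GCN(...))
        H = gcn_layer(A_tilde, H, W, b)
    return H

rng = np.random.default_rng(0)
m_d, D_d = 5, 8                               # toy sizes, not the paper's settings
A = np.zeros((m_d, m_d))
for i in range(m_d - 1):                      # a toy chain of 5 bonded atoms
    A[i, i + 1] = A[i + 1, i] = 1.0
M_d = rng.integers(0, 2, size=(m_d, 74)).astype(float)
W_o = rng.normal(size=(D_d, 74)) * 0.1
layers = [(rng.normal(size=(D_d, D_d)) * 0.1, np.zeros(D_d)) for _ in range(3)]

F_D = drug_encoder(A, M_d, W_o, layers)
print(F_D.shape)
```

Node-level outputs are kept (no pooling over atoms) so that each row of `F_D` can later attend to protein positions.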

Feature encoder for protein

To enhance the protein sequence feature representation and capture long-range relationships between sequence tokens, we introduce a protein feature encoder that combines a CNN and a Transformer. A traditional CNN may struggle with long sequences due to its limited local receptive field, so we add the global attention mechanism of the Transformer to capture long-range dependencies in protein sequences. By fusing the local perception capability of the CNN with the global attention mechanism of the Transformer, our model simultaneously considers local features and global context information in protein sequences, thereby extracting more effective protein features.

It is worth noting that we add a 1D CNN before the feed-forward layer of the Transformer to process local information. By sliding the convolution kernel along the protein sequence, we capture the local patterns and substructures of the protein. Combined with the advantages of the Transformer in handling long-range dependencies, our model fuses local and global information during protein feature encoding, which enhances the representation of protein sequence features. In our work, a three-layer protein feature encoder is used to capture protein features, as shown in Fig. 1b, where each layer includes a multi-head self-attention block, a CNN and a feed-forward neural network. Specifically, the protein sequence is input to the protein feature encoder as the feature matrix \(F_p\in \mathbb {R}^{l_p\times D_p}\), where \(l_p\) is the length of the protein sequence and \(D_p\) is the protein embedding dimension. The matrices \(Q\in \mathbb {R}^{l_p\times D_p}\), \(K\in \mathbb {R}^{l_p\times D_p}\) and \(V\in \mathbb {R}^{l_p\times D_p}\), which lie in different feature spaces, are generated from the feature matrix \(F_p\) by linear layers as follows:

$$\begin{aligned} {\left\{ \begin{array}{ll}Q=F_p\cdot W_Q+b_Q\\ K=F_p\cdot W_K+b_K,\\ V=F_p\cdot W_V+b_V\end{array}\right. } \end{aligned}$$

(2)

where \(W_Q\in \mathbb {R}^{D_p\times D_p}\), \(W_K\in \mathbb {R}^{D_p\times D_p}\), \(W_V\in \mathbb {R}^{D_p\times D_p}\) are learnable parameter weights. \(b_Q\), \(b_K\) and \(b_V\) are bias vectors. Given Q, K and V matrices, the self-attention layer computes the attention weights as follows:

$$\begin{aligned} \text {Attention}(Q,K,V)=\text {Softmax}(\frac{Q\cdot K^\top }{\sqrt{d_k}})V, \end{aligned}$$

(3)

where \(d_k\) is the dimension of K. The output \(X_M\) of the multi-head attention layer is generated as follows:

$$\begin{aligned} X_M=\text {MultiHead}(Q,K,V)=\text {Concat}(\text {Attention}(Q,K,V))W_M+b_M, \end{aligned}$$

(4)

where \(W_M\in \mathbb {R}^{D_p\times D_p}\) is the learnable weight matrix and \(b_M\) is the bias vector.
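Equations (2) to (4) can be sketched in NumPy as follows. Splitting the projected \(Q\), \(K\), \(V\) into equal column slices, one per head, is an assumption (the usual Transformer convention); the text only states that the per-head attention outputs are concatenated. The helper names, the parameter dictionary, and the toy sizes are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """Scaled dot-product attention, Eq. (3); also returns the weights."""
    d_k = K.shape[-1]
    weights = softmax(Q @ K.T / np.sqrt(d_k))
    return weights @ V, weights

def multi_head_self_attention(F_p, params, heads):
    """Eqs. (2) and (4): project, attend per head, concatenate, project again."""
    Q = F_p @ params["W_Q"] + params["b_Q"]
    K = F_p @ params["W_K"] + params["b_K"]
    V = F_p @ params["W_V"] + params["b_V"]
    d_head = Q.shape[1] // heads
    outs = []
    for h in range(heads):
        s = slice(h * d_head, (h + 1) * d_head)   # assumed per-head column slice
        out, _ = attention(Q[:, s], K[:, s], V[:, s])
        outs.append(out)
    return np.concatenate(outs, axis=1) @ params["W_M"] + params["b_M"]

rng = np.random.default_rng(0)
l_p, D_p, heads = 10, 8, 2                    # toy sizes, not the paper's settings
F_p = rng.normal(size=(l_p, D_p))
params = {}
for name in ("Q", "K", "V", "M"):
    params["W_" + name] = rng.normal(size=(D_p, D_p)) * 0.1
    params["b_" + name] = np.zeros(D_p)

X_M = multi_head_self_attention(F_p, params, heads)
print(X_M.shape)
```

Each row of the attention weights sums to one, so every output position is a convex combination of value rows drawn from the whole sequence.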

The multi-head attention layer extracts information from diverse representation subspaces, enhancing model robustness. As a result, long-range relationships between amino acids spanning the entire sequence can be learned through the self-attention weights. Additionally, the first Add & Norm layer applies a residual connection with the original protein feature matrix \(F_p\), followed by layer normalization:

$$\begin{aligned} X_{AN}=\text {LayerNorm}(F_{p}+X_{M}), \end{aligned}$$

(5)

Subsequently, a three-layer CNN is inserted after the first Add & Norm layer to extract local features in the protein sequence:

$$\begin{aligned} X_{CNN}=\textrm{CNN}\left( X_{AN}\right) , \end{aligned}$$

(6)

After the second Add & Norm layer, we derive the protein feature map \(F_{P}\in \mathbb {R}^{l_p \times D_p}\) as follows:

$$\begin{aligned} F_P=\text {LayerNorm}(X_{CNN}+X_{AN}), \end{aligned}$$

(7)
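The residual and convolution steps of Eqs. (5) to (7) can be sketched as below. This follows the equations only, so the feed-forward sublayer mentioned in the text is omitted. The kernel size of 3, the zero "same" padding, the ReLU between convolutions, and the parameter-free layer normalization are all assumptions, and the function names are hypothetical. `X_M` stands in for the multi-head attention output.

```python
import numpy as np

def layer_norm(X, eps=1e-5):
    """Normalize each position over its channels (no learnable scale or shift)."""
    mean = X.mean(axis=-1, keepdims=True)
    var = X.var(axis=-1, keepdims=True)
    return (X - mean) / np.sqrt(var + eps)

def conv1d_same(X, kernel, bias):
    """1D convolution along the sequence. X: (l, C_in); kernel: (k, C_in, C_out), odd k."""
    k = kernel.shape[0]
    pad = k // 2
    Xp = np.pad(X, ((pad, pad), (0, 0)))      # zero padding keeps the length l
    out = np.stack([np.einsum("kc,kco->o", Xp[t:t + k], kernel)
                    for t in range(X.shape[0])])
    return out + bias

def encoder_block_tail(F_p, X_M, conv_params):
    X_AN = layer_norm(F_p + X_M)              # Eq. (5): first Add & Norm
    X = X_AN
    for kernel, bias in conv_params:          # Eq. (6): three-layer CNN
        X = np.maximum(conv1d_same(X, kernel, bias), 0.0)
    return layer_norm(X + X_AN)               # Eq. (7): second Add & Norm

rng = np.random.default_rng(0)
l_p, D_p = 10, 8                              # toy sizes, not the paper's settings
F_p = rng.normal(size=(l_p, D_p))
X_M = rng.normal(size=(l_p, D_p))             # placeholder for the attention output
conv_params = [(rng.normal(size=(3, D_p, D_p)) * 0.1, np.zeros(D_p))
               for _ in range(3)]

F_P = encoder_block_tail(F_p, X_M, conv_params)
print(F_P.shape)
```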

Cross-attention module

After obtaining the feature maps for drugs and proteins through the feature encoders, we introduce a cross-attention module to model the interaction between drugs and proteins, thereby capturing enhanced representations of their interaction and providing a more reliable feature representation for DTI prediction. By exchanging the key and value of protein attention with the key and value of drug attention, information flows in both directions between drug and protein, capturing the interaction features between the drug and the protein target. In this process, protein features adjust their own representation according to attention weights computed against the drug features, and vice versa. This interaction and adjustment mechanism enables the cross-attention module to promote information flow across feature maps, fuse drug and protein features effectively, and extract a more comprehensive DTI feature representation. The cross-attention module is depicted in Fig. 1c and primarily consists of drug attention and protein attention.

In this section, we set \(D_e=D_d=D_p\). For drug attention, the drug feature map \(F_D\) is passed through a linear layer to compute the drug query matrix \(Q_D^i\in \mathbb {R}^{m_d\times d_{head}}\), while the protein feature map \(F_P\) is passed through linear layers to compute the drug key matrix \(K_D^i\in \mathbb {R}^{l_p\times d_{head}}\) and value matrix \(V_D^i\in \mathbb {R}^{l_p\times d_{head}}\). The query, key and value for the drug are obtained as follows:

$$\begin{aligned} {\left\{ \begin{array}{ll}Q_D^i=F_D\cdot W_q^i\\ K_D^i=F_P\cdot W_k^i ,\\ V_D^i=F_P\cdot W_v^i\end{array}\right. } \end{aligned}$$

(8)

where \(W_q^i, W_k^i, W_v^i\in \mathbb {R}^{D_e\times d_{head}}\) are different weight matrices in the linear layers and \(d_{head}=D_{e}/heads\) is the channel dimension per head. Here \(i=1,2,\cdots ,heads\), where \(heads\) is the number of attention heads.

Protein attention follows a process similar to drug attention. The protein feature map \(F_P\) is input into a linear layer to compute the protein query matrix \(Q_P^i\in \mathbb {R}^{l_p\times d_{head}}\), while the drug feature map \(F_D\) is used to generate the protein key matrix \(K_P^i\in \mathbb {R}^{m_d\times d_{head}}\) and protein value matrix \(V_P^i\in \mathbb {R}^{m_d\times d_{head}}\). The queries, keys and values of proteins are calculated as follows:

$$\begin{aligned} {\left\{ \begin{array}{ll}Q_P^i=F_P\cdot W_q^i\\ K_P^i=F_D\cdot W_k^i ,\\ V_P^i=F_D\cdot W_v^i\end{array}\right. } \end{aligned}$$

(9)

where the weight matrices \(W_q^i\in \mathbb {R}^{D_e\times d_{head}}\), \(W_k^i\in \mathbb {R}^{D_e\times d_{head}}\) and \(W_v^i\in \mathbb {R}^{D_e\times d_{head}}\) are shared with drug attention. Through the application of a softmax function, the drug and protein attention matrices are computed as:

$$\begin{aligned}{} & {} A_D^i=\textrm{Softmax}\left( \frac{Q_D^i\cdot K_D^i {^\top }}{\sqrt{d_{K_D^i}}}\right) , \end{aligned}$$

(10)

$$\begin{aligned}{} & {} A_P^i=\textrm{Softmax}\left( \frac{Q_P^i\cdot K_P^i {^\top }}{\sqrt{d_{K_P^i}}}\right) , \end{aligned}$$

(11)

where \(d_{K_D^i}=d_{K_P^i}=d_{head}\) is the dimension of K for drug and protein. The drug/protein feature map for each head is obtained by multiplying the drug/protein attention matrix of that head with the corresponding drug/protein value matrix. Subsequently, the drug/protein feature maps of all attention heads are concatenated along the channel dimension and fed into a linear layer to obtain the attended drug feature map \(Z_D\in \mathbb {R}^{m_d\times D_d}\) and the attended protein feature map \(Z_P\in \mathbb {R}^{l_p\times D_p}\):

$$\begin{aligned}{} & {} Z_D=\text {Concat}(A_D^i\times V_D^i)\times W_Z, \end{aligned}$$

(12)

$$\begin{aligned}{} & {} Z_P=\text {Concat}(A_P^i\times V_P^i)\times W_Z, \end{aligned}$$

(13)

where \(i=1,2,\cdots ,heads\) and \(W_Z\in \mathbb {R}^{D_e \times D_e}\) is the shared weight matrix.
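The cross-attention of Eqs. (8) to (13) can be sketched in NumPy as follows. The same projection weights are reused for both directions, as stated above; only the roles of query source and key/value source are swapped. The function name `cross_attention`, the stacked weight layout `(heads, D_e, d_head)`, and the toy sizes are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(F_D, F_P, W_q, W_k, W_v, W_Z):
    """Return (Z_D, Z_P). W_q, W_k, W_v have shape (heads, D_e, d_head)."""
    def attend(F_query, F_context):
        outs = []
        for i in range(W_q.shape[0]):
            Q = F_query @ W_q[i]                        # queries from one modality
            K = F_context @ W_k[i]                      # keys from the other
            V = F_context @ W_v[i]                      # values from the other
            A = softmax(Q @ K.T / np.sqrt(K.shape[1]))  # Eqs. (10) and (11)
            outs.append(A @ V)
        return np.concatenate(outs, axis=1) @ W_Z       # Eqs. (12) and (13)
    Z_D = attend(F_D, F_P)    # drug queries attend over protein positions
    Z_P = attend(F_P, F_D)    # protein queries attend over drug atoms
    return Z_D, Z_P

rng = np.random.default_rng(0)
m_d, l_p, D_e, heads = 5, 10, 8, 2            # toy sizes, not the paper's settings
d_head = D_e // heads
F_D = rng.normal(size=(m_d, D_e))
F_P = rng.normal(size=(l_p, D_e))
W_q, W_k, W_v = (rng.normal(size=(heads, D_e, d_head)) * 0.1 for _ in range(3))
W_Z = rng.normal(size=(D_e, D_e)) * 0.1

Z_D, Z_P = cross_attention(F_D, F_P, W_q, W_k, W_v, W_Z)
print(Z_D.shape, Z_P.shape)
```

Because the queries set the number of output rows, `Z_D` keeps one row per atom and `Z_P` one row per residue, which is what allows the later element-wise combination with \(F_D\) and \(F_P\).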

Next, the attended feature maps are averaged with the original feature maps to obtain the final drug feature map \(F_{ZD}\in \mathbb {R}^{m_d\times D_d}\) and protein feature map \(F_{ZP}\in \mathbb {R}^{l_p\times D_p}\):

$$\begin{aligned}{} & {} F_{ZD}=0.5Z_D+0.5F_D, \end{aligned}$$

(14)

$$\begin{aligned}{} & {} F_{ZP}=0.5Z_P+0.5F_P, \end{aligned}$$

(15)

The drug and protein feature maps are downsampled by using a global max-pooling operation to generate one-dimensional drug feature vector \(d_{mp}\in \mathbb {R}^{D_d}\) and protein feature vector \(p_{mp}\in \mathbb {R}^{D_p}\):

$$\begin{aligned}{} & {} d_{mp}=\text {Maxpooling}(F_{ZD}), \end{aligned}$$

(16)

$$\begin{aligned}{} & {} p_{mp}=\text {Maxpooling}(F_{ZP}), \end{aligned}$$

(17)

Finally, we concatenate \(d_{mp}\) and \(p_{mp}\) to obtain the joint feature representation \(f\in \mathbb {R}^{2D_e}\):

$$\begin{aligned} f=\textrm{Concat}(d_{mp},p_{mp}), \end{aligned}$$

(18)
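A small worked example of Eqs. (14) to (18), using hand-picked toy matrices with \(D_e=2\) so the arithmetic can be checked by eye. The function name `joint_feature` and the input values are hypothetical.

```python
import numpy as np

def joint_feature(F_D, Z_D, F_P, Z_P):
    F_ZD = 0.5 * Z_D + 0.5 * F_D          # Eq. (14)
    F_ZP = 0.5 * Z_P + 0.5 * F_P          # Eq. (15)
    d_mp = F_ZD.max(axis=0)               # Eq. (16): max over atoms
    p_mp = F_ZP.max(axis=0)               # Eq. (17): max over residues
    return np.concatenate([d_mp, p_mp])   # Eq. (18): f has 2 * D_e entries

# Toy drug with 2 atoms and toy protein with 3 residues
F_D = np.array([[1.0, 2.0], [3.0, 0.0]])
Z_D = np.array([[3.0, 0.0], [1.0, 4.0]])
F_P = np.array([[0.0, 0.0], [2.0, 2.0], [4.0, 6.0]])
Z_P = np.array([[2.0, 2.0], [0.0, 0.0], [0.0, 2.0]])

f = joint_feature(F_D, Z_D, F_P, Z_P)
print(f)  # [2. 2. 2. 4.]
```

Global max-pooling removes the dependence on \(m_d\) and \(l_p\), so drugs and proteins of different sizes all map to a joint vector of the same length.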

Drug-target interaction prediction

In order to predict the DTI probability, we input the joint representation f into the decoder, which consists of a fully connected classification layer. Finally, the DTI probability p is generated as follows:

$$\begin{aligned} p=\sigma \left( Wf+b\right) , \end{aligned}$$

(19)

where W and b are learnable weight matrix and bias vector.

During model training, we employ backpropagation to concurrently optimize the learnable parameters. Our objective in training is to minimize the cross-entropy loss function:

$$\begin{aligned} \mathcal {L}=-\sum _i\left( y_i\log {(p_i)}+(1-y_i)\log {(1-p_i)}\right) +\frac{1}{2}\lambda ||\theta ||_2^2, \end{aligned}$$

(20)

where \(y_i\) denotes the ground-truth label of the i-th drug-target pair, \(p_i\) is the DTI score predicted by the model for that pair, \(\theta\) is the set of learnable weight matrices and bias vectors, and \(\lambda\) is the L2 regularization hyperparameter used to prevent overfitting.
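Equations (19) and (20) amount to a logistic output layer with an L2-regularized binary cross-entropy. A minimal NumPy sketch follows; the clipping constant `eps` is an added numerical safeguard not mentioned in the text, and the function names and toy values are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def predict(f, W, b):
    """Eq. (19): DTI probability for each joint feature row in f."""
    return sigmoid(f @ W + b)

def dti_loss(p, y, theta, lam, eps=1e-12):
    """Eq. (20): summed binary cross-entropy plus (lambda / 2) * ||theta||^2."""
    p = np.clip(p, eps, 1.0 - eps)        # avoid log(0); not part of Eq. (20)
    bce = -np.sum(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))
    l2 = 0.5 * lam * sum(np.sum(t ** 2) for t in theta)
    return bce + l2

f_batch = np.array([[2.0, 2.0, 2.0, 4.0],
                    [1.0, 0.0, 3.0, 1.0]])   # two toy joint features
y = np.array([1.0, 0.0])

W0, b0 = np.zeros(4), 0.0                    # zero weights give p = 0.5 everywhere
p0 = predict(f_batch, W0, b0)
loss_plain = dti_loss(p0, y, [W0], lam=0.0)          # equals 2 * ln 2
loss_reg = dti_loss(p0, y, [np.ones(4)], lam=0.5)    # adds 0.5 * 0.5 * 4 = 1
print(loss_plain, loss_reg)
```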

Cross-domain adaptation enhances generalization

Deep learning models perform well on data drawn from the same distribution as the training data (i.e., in-domain), but their performance on data from a different distribution (i.e., cross-domain) is often unsatisfactory. To address this, we employ the CDAN module to improve the ability of the CAT-DTI model to generalize from a source domain rich in labeled data to a target domain containing only unlabeled data. Figure 2 shows the framework after integrating the CDAN module into CAT-DTI (i.e., \({\text {CAT-DTI}}_{\textrm{CDAN}}\)), which consists of three key components: Feature Extractor \(F\left( *\right)\), Decoder \(G\left( *\right)\) and Discriminator \(D\left( *\right)\).

Fig. 2

Diagram of cross-domain adaptation process. CDAN is a domain adaptation technique designed to address domain shift challenges with different distributions. We utilize CDAN to integrate the joint representation f of the source and target domain, along with classifier prediction g into the joint conditional representation distinguished by the discriminator. The discriminator is structured as a three-layer fully connected network with the specific goal of distinguishing the target domain from source domain by minimizing domain classification error

In the cross-domain task, we are given \(N_S\) labeled drug-target pairs \(P_s=\{(x_s^i,y_s^i)\}_{i=1}^{N_S}\) in the source domain and \(N_T\) unlabeled drug-target pairs \(P_t=\{x_t^j\}_{j=1}^{N_T}\) in the target domain. We rely on CDAN to align the sample distributions and thereby optimize cross-domain prediction performance. The feature extractor \(F\left( *\right)\) comprises the drug and protein feature encoders together with the cross-attention module, and generates a joint representation of the input domain data, namely \(f_s^i=F(x_s^i)\) and \(f_t^j=F(x_t^j)\). For the decoder \(G\left( *\right)\), we employ a fully connected classification layer followed by a softmax function to obtain the predicted classification results \(g_s^i=G(f_s^i)\in \mathbb {R}^2\) and \(g_t^j=G(f_t^j)\in \mathbb {R}^2\). Subsequently, the joint representation f and the classifier prediction g are embedded into a joint conditional representation \(c\in \mathbb {R}^{4D_e}\) (since \(f\in \mathbb {R}^{2D_e}\) and \(g\in \mathbb {R}^{2}\)), which is defined as follows:

$$\begin{aligned} c=\text {FLATTEN}(f\otimes g), \end{aligned}$$

(21)

where \(\text {FLATTEN}\) performs a flattening operation on the outer product of the f and g vectors and \(\otimes\) is the outer product.
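Equation (21) is a one-line operation. The sketch below uses a toy 3-dimensional \(f\) to show that the flattened outer product has as many entries as the product of the lengths of \(f\) and \(g\); the function name and the values are hypothetical.

```python
import numpy as np

def conditional_representation(f, g):
    """Eq. (21): flatten the outer product of f and g into one vector."""
    return np.outer(f, g).reshape(-1)

f = np.array([1.0, 2.0, 3.0])   # toy joint feature
g = np.array([0.8, 0.2])        # softmax output over the two classes

c = conditional_representation(f, g)
print(c.shape)  # (6,)
```

Each entry of `c` is a feature scaled by a class probability, so the discriminator sees the features conditioned on what the classifier currently predicts.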

Adhering to CDAN principles, we employ a domain discriminator \(D\left( *\right)\) to align the joint representation f and predicted classification distribution g of the source and target domains. \(D\left( *\right)\) is a domain discriminator composed of a three-layer FCN that learns to distinguish whether a joint conditional representation c originates from the source or target domain. \(F\left( *\right)\) and \(G\left( *\right)\) are trained to minimize the cross-entropy loss \(\mathcal {L}\) of the source domain with source label information, generating a joint conditional representation c that confuses the discriminator \(D\left( *\right)\). In the cross-domain task, we utilize two losses: one for optimizing classification prediction and the other for optimizing the distribution alignment of the source and target domain:

$$\begin{aligned}{} & {} \mathcal {L}_S(F,G)=\mathbb {E}_{(x_s^i,y_s^i)\thicksim P_s}\mathcal {L}\left( G\left( F\left( x_s^i\right) \right) ,y_s^i\right) , \end{aligned}$$

(22)

$$\begin{aligned} \mathcal {L}_{adv}(F,G,D)=\mathbb {E}_{x_t^i\thicksim P_t}\log \left( 1-D\left( f_t^i,g_t^i\right) \right) +\mathbb {E}_{x_s^j\thicksim P_s}\log \left( D\left( f_{s}^j,g_s^j\right) \right) , \end{aligned}$$

(23)

where \(\mathcal {L}_S\) is the cross-entropy loss on the labeled source domain and \(\mathcal {L}_{adv}\) is the adversarial loss for the domain discriminator.

The optimization problem is written as a minimax paradigm:

$$\begin{aligned} \max _{D}\min _{F,G}\mathcal {L}_S(F,G)-\omega \mathcal {L}_{adv}(F,G,D), \end{aligned}$$

(24)

where \(\omega\) is a hyperparameter for weighting \(\mathcal {L}_{adv}\). By introducing adversarial training through \(\mathcal {L}_{adv}\), the difference in data distribution between the source domain and target domain is reduced, thereby enhancing the generalization ability of cross-domain prediction.
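To make the adversarial term of Eq. (23) concrete, the sketch below evaluates it for given discriminator outputs and combines it with a source loss exactly as written in Eq. (24). It only evaluates the formulas; the alternating or gradient-reversal training procedure is not shown. The clipping constant, the function names, and all numeric values (including the source loss and \(\omega\)) are hypothetical.

```python
import numpy as np

def adversarial_loss(d_source, d_target, eps=1e-12):
    """Eq. (23), with d_* the discriminator's probability that a sample is source."""
    d_source = np.clip(d_source, eps, 1.0 - eps)   # numerical safeguard only
    d_target = np.clip(d_target, eps, 1.0 - eps)
    return np.mean(np.log(1.0 - d_target)) + np.mean(np.log(d_source))

def objective(L_S, L_adv, omega):
    """The combined quantity in Eq. (24)."""
    return L_S - omega * L_adv

# A fully confused discriminator outputs 0.5 for every sample
L_adv_confused = adversarial_loss(np.array([0.5, 0.5]), np.array([0.5, 0.5]))

# A sharp discriminator separates the two domains well
L_adv_sharp = adversarial_loss(np.array([0.9, 0.95]), np.array([0.1, 0.05]))

print(L_adv_confused, L_adv_sharp)
print(objective(L_S=0.7, L_adv=L_adv_confused, omega=1.0))
```

The adversarial term is at most zero and approaches zero as the discriminator separates the domains perfectly; a discriminator that cannot tell the domains apart yields \(2\log 0.5\).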


