Whole Transcriptome Co-expression Completion Tutorial

This tutorial uses the `coxformer.embedding` module to train a graph neural network that predicts gene co-expression values, completing the whole-transcriptome co-expression network.

1. Import Dependencies

[ ]:
import os
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import pickle
import matplotlib.pyplot as plt
from tqdm import tqdm
import gc
from torch.optim.lr_scheduler import CosineAnnealingLR
import warnings
from tensorboardX import SummaryWriter
import time

from coxformer.embedding.data import CoxformerDataset, split_observed_coexpression_edges
from coxformer.embedding.model import CoxformerGCN
from coxformer.embedding.train import CoxformerGCNTrainer
from coxformer.embedding.infer import build_completed_coexpression_matrix, print_coexpression_completion_report

2. Configuration

[3]:
# Experiment identifier; used to name the out/ and runs/ sub-directories.
project_name = 'top50_full_1e-3_L2'

# Graph construction / sampling / batching hyper-parameters.
top_k_edges = 50           # edges kept per node when the graph is built
num_neighbors = [20, 15]   # neighbour sample sizes handed to the trainer's loader
batch_size = 512

# When True, the training-loop cell skips training.
only_eval = False

# Input pickles.
coexpression_file = 'data/coexpression.pkl'
correlation_file = 'data/sc_correlation_origin.pkl'
embedding_file = 'data/GPT_embedding.pkl'

for _label, _value in (
    ('Project name', project_name),
    ('Top-k edges', top_k_edges),
    ('Neighbor samples', num_neighbors),
    ('Batch size', batch_size),
):
    print(f'{_label}: {_value}')
Project name: top50_full_1e-3_L2
Top-k edges: 50
Neighbor samples: [20, 15]
Batch size: 512

3. Directory Setup & Device Initialization

[5]:
# Create the per-project output and TensorBoard directories (reported if they already exist).
for target_dir in (f'out/{project_name}', f'runs/{project_name}'):
    if not os.path.exists(target_dir):
        os.makedirs(target_dir, exist_ok=True)
        print(f'Created directory: {target_dir}')
        continue
    print(f'Directory {target_dir} already exists, skipping.')

# Prefer the GPU when one is visible to PyTorch.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print(f'\nUsing device: {device}')

if use_cuda:
    total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f'GPU total memory: {total_bytes / 1024**3:.2f} GB')
Directory out/top50_full_1e-3_L2 already exists, skipping.
Directory runs/top50_full_1e-3_L2 already exists, skipping.

Using device: cuda
GPU total memory: 44.42 GB

4. Data Loading & Graph Construction

[6]:
# Load the co-expression, correlation and embedding pickles, then build the graph
# keeping top_k_edges edges per node.
dataset = CoxformerDataset(coexpression_file, correlation_file, embedding_file)
graph_bundle = dataset.create_enhanced_graph_data(top_k_edges)

# Unpack in the order returned by create_enhanced_graph_data. target_min/target_max
# are later used to de-normalise predictions; edge_corr_min/edge_corr_max are
# stored with the saved results.
(data, gene_to_idx, all_nodes_genes, common_genes,
 target_min, target_max, edge_corr_min, edge_corr_max) = graph_bundle
Loading data...
Correlation data shape: (32101, 2)
Coexpression data shape: (18858, 2)
Embedding data shape: (43661, 2)
All node genes count (corr ∩ emb): 32016
Labeled genes count (coexp ∩ corr ∩ emb): 18184
Unlabeled genes count: 13832
Preparing node features...
Preparing correlation data...
Preparing coexpression data...
Reordering correlation adjacency matrix...
Reordering coexpression adjacency matrix...
Building graph edges, keeping top-50 edges per node...
Coexpression dictionary contains 330639672 edges
Building edges: 100%|██████████| 32016/32016 [02:02<00:00, 260.59it/s]
Total edges: 1568784
Labeled edges: 200628
Unlabeled edges: 1368156
Graph statistics:
Number of nodes: 32016
Number of edges: 1568784
Labeled edges: 200628
Average degree: 49.00
Edge feature range: [-1.0000, 1.0000]
Label range: [-1.0000, 1.0000]

5. Data Splitting

[ ]:
# Split the labelled (observed) co-expression edges into train/val/test index sets.
train_indices, val_indices, test_indices = split_observed_coexpression_edges(data)

data = data.to(device)

print('Data split summary:')
print(f'Total edges: {data.edge_index.shape[1]}')
print(f'Labeled edges: {data.label_mask.sum().item()}')
print(f'Unlabeled edges:{(~data.label_mask).sum().item()}')
print(f'Train edges: {len(train_indices)}')
print(f'Val edges: {len(val_indices)}')
print(f'Test edges: {len(test_indices)}')

# Leakage guard: validation/test metrics are only meaningful if those edges were
# not also used for training. Warn (rather than fail) so the notebook still runs,
# but make any overlap visible. Assumes the splits are integer edge-index
# collections (their lengths above are edge counts, not full-length masks).
_train_idx = torch.as_tensor(train_indices).reshape(-1).cpu()
for _split_name, _split_idx in (('val', val_indices), ('test', test_indices)):
    _split_tensor = torch.as_tensor(_split_idx).reshape(-1).cpu()
    _shared = int(torch.isin(_split_tensor, _train_idx).sum())
    if _shared:
        warnings.warn(
            f'{_shared} of {len(_split_tensor)} {_split_name} edges also appear in '
            'the training split; metrics reported on this split are not held-out.'
        )
Splitting labeled edges:
Total labeled edges: 200628
Data split summary:
Total edges: 1568784
Labeled edges: 200628
Unlabeled edges:1368156
Train edges: 200628
Val edges: 200628
Test edges: 200628

6. Model & Trainer Initialization

[ ]:
# Model input sizes are read off the graph so the network matches the loaded features.
input_dim = data.x.shape[1]
edge_dim = data.edge_attr.shape[1]
hidden_dims = [256, 128, 64]

print('Model configuration:')
for _name, _val in (
    ('Node feature dim', input_dim),
    ('Edge feature dim', edge_dim),
    ('Hidden dims', hidden_dims),
):
    print(f'{_name}: {_val}')

model = CoxformerGCN(input_dim, hidden_dims, edge_dim=edge_dim, dropout=0.3, use_edge_features=True)
trainer = CoxformerGCNTrainer(model, device, num_neighbors=num_neighbors, batch_size=batch_size)
Model configuration:
Node feature dim: 3072
Edge feature dim: 1
Hidden dims: [256, 128, 64]

7. Training Setup

[9]:
# Adam with a small weight decay; the learning rate follows a cosine schedule
# down to eta_min.
# NOTE: T_max=200 should match num_epochs in the training-loop cell, which calls
# scheduler.step() once per epoch.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = CosineAnnealingLR(optimizer, T_max=200, eta_min=1e-6)
criterion = nn.MSELoss()  # passed to the trainer for training and evaluation

log_dir = f'runs/{project_name}/training'
writer = SummaryWriter(log_dir)
print(f'TensorBoard log dir: {log_dir}')

# History buffers; the training-loop cell appends to train_losses only.
train_losses, val_losses, val_r2_scores = [], [], []
TensorBoard log dir: runs/top50_full_1e-3_L2/training

8. Training Loop

[ ]:
# Train for num_epochs epochs (skipped entirely when only_eval is True).
# NOTE: there is no validation-based model selection; the checkpoint written
# here is the final-epoch model despite the 'best_model.pth' file name.
num_epochs = 200
checkpoint_path = f'out/{project_name}/best_model.pth'
training_start_time = time.time()

if not only_eval:
    for epoch in range(num_epochs):
        train_loss = trainer.train_epoch(data, optimizer, criterion, train_indices)

        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.flush()

        print(f'Epoch {epoch:03d}, Train Loss: {train_loss:.4f}')
        train_losses.append(train_loss)
        scheduler.step()

        # Release host and cached device memory between epochs.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Only persist weights that were actually trained. Saving unconditionally
    # would, in only_eval mode, overwrite an existing trained checkpoint with
    # the untrained model built in the model-initialisation cell.
    torch.save(model.state_dict(), checkpoint_path)

training_end_time = time.time()
training_duration = training_end_time - training_start_time
print(f'\nTraining time: {training_duration:.2f} s')
if only_eval:
    print(f'only_eval=True: training skipped; existing checkpoint {checkpoint_path} is used')
else:
    print(f'Model saved to: {checkpoint_path}')

9. Test Evaluation

[ ]:
# Reload the saved checkpoint and evaluate it on the test split.
# map_location remaps stored tensors onto the current device, so a checkpoint
# written on a GPU can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load(f'out/{project_name}/best_model.pth', map_location=device))

print('Running test evaluation...')
test_loss, test_mse, test_mae, test_r2, test_pred, test_true = trainer.evaluate(
    data, criterion, test_indices
)

print('Test results:')
print(f'MSE: {test_mse:.4f}')
print(f'MAE: {test_mae:.4f}')
print(f'R2 Score: {test_r2:.4f}')

writer.add_scalar('Loss/test', test_loss, 0)
writer.add_scalar('R2/test', test_r2, 0)
writer.flush()
writer.close()
Running test evaluation...
Evaluating:   0%|          | 0/392 [00:00<?, ?it/s]Evaluating: 100%|██████████| 392/392 [03:17<00:00,  1.98it/s]
Test results:
MSE: 0.0063
MAE: 0.0581
R2 Score: 0.6430

10. Visualization

[14]:
# Undo the [-1, 1] normalisation (see "Label range" in the graph statistics) using
# the target bounds returned by create_enhanced_graph_data.
test_true_denorm = (test_true + 1) / 2 * (target_max - target_min) + target_min
test_pred_denorm = (test_pred + 1) / 2 * (target_max - target_min) + target_min

plt.figure(figsize=(15, 5))

# Subplot 1: per-edge comparison in test-set order (the x-axis is an edge index).
plt.subplot(1, 3, 1)
plt.plot(test_true_denorm, label='Ground Truth', color='blue', alpha=0.7)
plt.plot(test_pred_denorm, label='Prediction', color='orange', alpha=0.7)
plt.title('Prediction vs Ground Truth')
plt.xlabel('Sample Index')
plt.ylabel('Value')
plt.legend()
plt.grid(True, alpha=0.3)

# Subplot 2: scatter plot with the y = x identity line for reference.
plt.subplot(1, 3, 2)
plt.scatter(test_true_denorm, test_pred_denorm, alpha=0.5, s=1)
diag_lo, diag_hi = test_true_denorm.min(), test_true_denorm.max()
plt.plot([diag_lo, diag_hi], [diag_lo, diag_hi], 'r--', lw=2)
plt.xlabel('Ground Truth')
plt.ylabel('Prediction')
plt.title(f'Scatter Plot (R\u00b2 = {test_r2:.3f})')
plt.grid(True, alpha=0.3)

# Subplot 3: per-epoch training loss. No validation loss is recorded by the
# training loop, so only the training curve is drawn. When training was skipped
# (only_eval) the list is empty and a placeholder note is shown instead.
plt.subplot(1, 3, 3)
if train_losses:
    plt.plot(train_losses, label='Train Loss', alpha=0.7)
    plt.legend()
else:
    plt.text(
        0.5, 0.5, 'No training history\n(only_eval=True)',
        ha='center', va='center', transform=plt.gca().transAxes,
    )
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training History')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f'out/{project_name}/predicted_coexpress.png', dpi=300, bbox_inches='tight')
plt.show()
print(f'Figure saved to: out/{project_name}/predicted_coexpress.png')
../_images/notebooks_Co-expression_completion_tutorial_20_0.png
Figure saved to: out/top50_full_1e-3_L2/predicted_coexpress.png

11. Full Inference

[19]:
# Predict a weight for every edge of the graph (labelled and unlabelled).
print('Running inference on all edges...')
model.eval()

num_edges = data.edge_index.shape[1]
num_nodes = data.x.shape[0]

# Outputs are pre-allocated and filled BY EDGE POSITION, so entry k always
# corresponds to column k of data.edge_index; the matrix-building step pairs
# predictions with edges by position. Edges that cannot be predicted stay 0.0.
all_predictions = np.zeros(num_edges, dtype=np.float32)
# The edge feature (first edge_attr column) and label mask come straight from the
# graph, independent of whether a prediction succeeded for that edge.
all_edge_correlations = data.edge_attr.reshape(num_edges, -1)[:, 0].detach().cpu().numpy()
all_has_labels = data.label_mask.detach().cpu().numpy().astype(bool)
predicted_mask = np.zeros(num_edges, dtype=bool)
failed_batches = 0

inference_start_time = time.time()

with torch.no_grad():
    # Reusable global node id -> local sub-graph row lookup (-1 = not sampled).
    global_to_local = torch.full((num_nodes,), -1, dtype=torch.long, device=device)

    for i in tqdm(range(0, num_edges, batch_size), desc='Inferring all edges'):
        end_idx = min(i + batch_size, num_edges)
        batch_edges = data.edge_index[:, i:end_idx]
        batch_edge_attr = data.edge_attr[i:end_idx]
        unique_nodes = torch.unique(batch_edges.flatten())

        try:
            # NOTE(review): only the first mini-batch produced by the loader is
            # used. If the loader's batch size is smaller than len(unique_nodes),
            # some endpoints may be missing from the sub-graph and their edges stay
            # unpredicted — confirm against create_neighbor_loader.
            loader = trainer.create_neighbor_loader(data, input_nodes=unique_nodes, shuffle=False)
            batch_data = next(iter(loader)).to(device)

            node_embeddings = model(batch_data.x, batch_data.edge_index)

            # Vectorised replacement for the per-edge dict lookup.
            n_id = batch_data.n_id.to(device)
            global_to_local[n_id] = torch.arange(n_id.numel(), device=device)
            local_src = global_to_local[batch_edges[0]]
            local_dst = global_to_local[batch_edges[1]]
            global_to_local[n_id] = -1  # reset for the next batch

            valid = (local_src >= 0) & (local_dst >= 0)
            valid_np = valid.cpu().numpy()
            if valid_np.any():
                mapped = torch.stack([local_src[valid], local_dst[valid]])
                pred_w = model.predict_edges(node_embeddings, mapped, batch_edge_attr[valid])
                # Slicing yields a view, so boolean assignment writes into the
                # full arrays at exactly the positions of the valid edges.
                all_predictions[i:end_idx][valid_np] = pred_w.reshape(-1).float().cpu().numpy()
                predicted_mask[i:end_idx][valid_np] = True

        except Exception as e:
            # Best effort: a failing batch leaves its edges at 0.0 and is reported.
            failed_batches += 1
            print(f'Error in batch {i // batch_size}: {e}')

if torch.cuda.is_available():
    torch.cuda.empty_cache()

inference_end_time = time.time()
inference_duration = inference_end_time - inference_start_time

num_predicted = int(predicted_mask.sum())
print(f'Inference complete (elapsed: {inference_duration:.2f} s):')
print(f'Total edges: {len(all_predictions)}')
print(f'Labeled edges: {all_has_labels.sum()}')
print(f'Unlabeled edges:{(~all_has_labels).sum()}')
print(f'Predicted edges: {num_predicted} (left at 0.0: {num_edges - num_predicted})')
print(f'Failed batches: {failed_batches}')
if num_edges:
    print(f'Prediction range: [{all_predictions.min():.4f}, {all_predictions.max():.4f}]')
Running inference on all edges...
Inferring all edges:   0%|          | 0/3065 [00:00<?, ?it/s]Inferring all edges: 100%|██████████| 3065/3065 [22:13<00:00,  2.30it/s]
Inference complete (elapsed: 1333.43 s):
Total edges: 1568784
Labeled edges: 200625
Unlabeled edges:1368159
Prediction range: [-0.6759, 0.9469]

12. Save Intermediate Results

[ ]:

# Bundle metrics, configuration, timing and a sample of edge predictions into one
# dict and pickle it. The completed matrix itself is written separately, to
# predicted_coexpress.pkl, by the matrix-building step, so this dict uses its own
# file to avoid being confused with (or clobbering) the matrix.
pred_denormalized = (all_predictions + 1) / 2 * (target_max - target_min) + target_min

# Seeded generator so the saved sample is reproducible across runs.
sample_rng = np.random.default_rng(0)
sample_size = min(20000, len(all_predictions))
sample_indices = sample_rng.choice(len(all_predictions), sample_size, replace=False)

results = {
    # Tensors moved to CPU so the pickle can be loaded without a GPU.
    'model_state_dict': {k: v.detach().cpu() for k, v in model.state_dict().items()},
    'gene_to_idx': gene_to_idx,
    # Only the first 1000 gene names are stored, to keep the file small.
    'all_nodes_genes': all_nodes_genes[:1000] if len(all_nodes_genes) > 1000 else all_nodes_genes,
    'common_genes': common_genes[:1000] if len(common_genes) > 1000 else common_genes,
    'normalization_params': {
        'target_min': target_min,
        'target_max': target_max,
        'edge_corr_min': edge_corr_min,
        'edge_corr_max': edge_corr_max,
    },
    'test_metrics': {'mse': test_mse, 'mae': test_mae, 'r2': test_r2},
    'training_stats': {
        'total_nodes': len(all_nodes_genes),
        'common_nodes': len(common_genes),
        'unlabeled_nodes': len(all_nodes_genes) - len(common_genes),
        'total_edges': len(all_predictions),
        'labeled_edges': int(all_has_labels.sum()),
        'unlabeled_edges': int((~all_has_labels).sum()),
    },
    'config': {
        # Records whether this run actually trained (False when only_eval is set).
        'training': not only_eval,
        'top_k_edges': top_k_edges,
        'hidden_dims': hidden_dims,
        'edge_dim': edge_dim,
        'batch_size': batch_size,
        'num_neighbors': num_neighbors,
        'use_edge_features': True,
    },
    'timing': {
        'training_duration': training_duration,
        'inference_duration': inference_duration,
    },
    'sample_predictions': {
        'predictions': all_predictions[sample_indices],
        'predictions_denormalized': pred_denormalized[sample_indices],
        'edge_correlations': all_edge_correlations[sample_indices],
        'has_labels': all_has_labels[sample_indices],
        'indices': sample_indices,
    },
}

results_path = f'out/{project_name}/predicted_coexpress_results.pkl'
with open(results_path, 'wb') as f:
    pickle.dump(results, f)
print(f'Intermediate results saved to: {results_path}')
Intermediate results saved to: out/top50_full_1e-3_L2/predicted_coexpress.pkl

13. Generate Full Co-expression Matrix

[ ]:
print('=' * 60)
print('Saving results')
print('=' * 60)

# The completed matrix and the run-metadata dict go to DIFFERENT files. If both
# used predicted_coexpress.pkl, the dict would overwrite the matrix DataFrame that
# the file-validation step reads back expecting a 'Coexpress' column.
matrix_path = f'out/{project_name}/predicted_coexpress.pkl'
results_path = f'out/{project_name}/predicted_coexpress_results.pkl'

save_start_time = time.time()
result_df, save_stats = build_completed_coexpression_matrix(
    data,
    all_nodes_genes,
    pred_denormalized,
    coexpression_file,
    gene_to_idx,
    save_path=matrix_path,
)
save_duration = time.time() - save_start_time

results['save_statistics'] = save_stats
results['timing']['save_duration'] = save_duration
results['matrix_path'] = matrix_path

with open(results_path, 'wb') as f:
    pickle.dump(results, f)
print(f'Completed co-expression matrix saved to: {matrix_path}')
print(f'Full results saved to: {results_path}')
============================================================
Saving results
============================================================
Loading original data...
Original coexpression data: 18858 genes
Training edge statistics: total=1568784, labeled=200628, unlabeled=1368156
Step 1: Batch filling predicted edges...
Filling predicted edges: 100%|██████████| 1568784/1568784 [00:04<00:00, 320755.98it/s]
Filled 1568784 predicted edges
Step 2: Overwriting with original data...
Overwriting original data: 100%|██████████| 32016/32016 [03:18<00:00, 161.63it/s]
Overwrote 330639672 original edges
Step 3: Calculating final statistics...
Marking original edges: 100%|██████████| 32016/32016 [03:10<00:00, 168.00it/s]
Final statistics:
Total edges: 166443155
├─ Original data edges: 165319836 (99.3%)
├─ Labeled prediction edges: 0 (0.0%)
└─ Unlabeled prediction edges: 1123319 (0.7%)
Creating DataFrame...
Save complete!
File: predicted_coexpress.pkl
Matrix size: 32016 x 32016
Full results saved to: out/top50_full_1e-3_L2/predicted_coexpress.pkl

14. Analysis Report & Performance Summary

[ ]:
# Print the summary report using the statistics dict returned by
# build_completed_coexpression_matrix in the previous step.
print_coexpression_completion_report(save_stats)

============================================================
Training Analysis
============================================================
Processing efficiency:
Using batch numpy operations, avoiding inefficient loops
Significantly improved processing speed

Training edge analysis:
Total training edges: 1,568,784
├─ Labeled edges: 200,628 (participate in loss calculation)
└─ Unlabeled edges: 1,368,156 (forward propagation only)
Unlabeled edge ratio: 87.2%

Final coexpression file:
Total genes: 32,016
Total edges: 166,443,155
├─ Original data edges: 165,319,836 (99.3%)
├─ Labeled prediction edges: 0 (0.0%)
└─ Unlabeled prediction edges: 1,123,319 (0.7%)

Core guarantees:
• Original high-quality data 100% retained
• Prediction edges based on consistent top-k logic
• Both labeled and unlabeled edges can be predicted
• Efficient processing, suitable for large-scale data
============================================================

15. File Validation

[38]:
# Reload the completed co-expression matrix and run basic sanity checks on it.
print('Validating saved file...')
matrix_path = f'out/{project_name}/predicted_coexpress.pkl'
try:
    with open(matrix_path, 'rb') as f:
        saved_df = pickle.load(f)

    # Fail with a clear message if the file holds something other than the
    # matrix DataFrame (e.g. a metadata dict written to the same path).
    if not isinstance(saved_df, pd.DataFrame) or 'Coexpress' not in saved_df.columns:
        raise TypeError(
            f'expected a DataFrame with a "Coexpress" column, got {type(saved_df).__name__}'
        )

    coexp = saved_df['Coexpress']
    n_genes = len(saved_df)
    n_matrix = len(coexp.iloc[0])
    print('Validation passed:')
    print(f'- Shape: {saved_df.shape}')
    print(f'- Gene count: {n_genes}')
    print(f'- Matrix size: {n_matrix} x {n_matrix}')

    # Symmetry check on the leading k x k block. A (k, n_matrix) row block cannot
    # be compared with its own (n_matrix, k) transpose, so restrict to a square.
    # NOTE(review): assumes row i and column i refer to the same gene — confirm
    # against build_completed_coexpression_matrix.
    k = min(10, n_genes, n_matrix)
    sample_matrix = np.array([np.asarray(row)[:k] for row in coexp.iloc[:k]])
    is_symmetric = np.allclose(sample_matrix, sample_matrix.T, atol=1e-10)
    sym_label = 'passed' if is_symmetric else 'failed'
    print(f'- Symmetry check: {sym_label}')

    non_zero_count = sum(np.count_nonzero(np.asarray(row)) for row in coexp)
    total_cells = n_genes * n_matrix
    sparsity = (1 - non_zero_count / total_cells) * 100 if total_cells else 0.0
    print(f'- Non-zero count: {non_zero_count:,}')
    print(f'- Sparsity: {sparsity:.2f}%')

except Exception as e:
    # Include the exception type: e.g. a bare KeyError otherwise prints only the key.
    print(f'Validation failed: {type(e).__name__}: {e}')
Validating saved file...
Validation failed: 'Coexpress'

16. Summary

[39]:
# Headline numbers for the completed network, taken from the save statistics.
completion_stats = save_stats
total_genes, orig_edges = completion_stats['total_genes'], completion_stats['original_edges']
pred_new = sum(completion_stats[key] for key in ('labeled_pred_edges', 'unlabeled_pred_edges'))

summary_lines = [
    f'1. Built a full co-expression network with {total_genes:,} genes',
    f'2. Retained all {orig_edges:,} original high-quality edges',
    f'3. Added {pred_new:,} predicted edges',
    f'4. Achieved R\u00b2 = {test_r2:.4f}',
]

print('Training complete!')
print('Key results:')
for summary_line in summary_lines:
    print(summary_line)

dataset.print_memory_usage()
Training complete!
Key results:
1. Built a full co-expression network with 32,016 genes
2. Retained all 165,319,836 original high-quality edges
3. Added 1,123,319 predicted edges
4. Achieved R² = 0.6430
Current memory usage: 30.29 GB