VAE-এর latent z কোন distribution থেকে আসে? Data-এর true posterior p(z|x) intractable — normalizing constant-এ ∫p(x|z)p(z)dz compute করা অসম্ভব। Variational Inference (VI) এই intractable posterior-কে একটি simpler family q(z|x) দিয়ে approximate করে — modern generative AI (VAE, Diffusion) এর মূল idea।
The Intractable Posterior Problem
Bayesian rule:
Denominator p(x) (marginal likelihood / evidence) — high-dimensional integral, intractable।
Solution: p(z|x)-কে approximate করি একটি tractable q(z|x; φ) দিয়ে।
KL Divergence Objective
q কতটা ভালো approximation সেটা measure করি KL divergence দিয়ে:
Problem: p(z|x) জানা নেই, তাই log p(z|x) compute করা যায় না!
ELBO — Evidence Lower BOund
Trick: KL-এর expression re-arrange করলে:
KL(q \| p(z|x)) \geq 0 — তাই ELBO ≤ log p(x) (hence "lower bound")।
Objective: ELBO maximize করলে দুটি কাজ হয়:
- \mathbb{E}_q[\log p(x|z)] বাড়ে → reconstruction quality বাড়ে।
- KL(q \| p(z)) কমে → q prior p(z)-এর কাছে চলে আসে (regularization)।
Reparameterization Trick
ELBO-র gradient \nabla_\phi \mathbb{E}_{q_\phi}[f(z)] বের করতে হবে। কিন্তু expectation-এর ভেতর φ আছে।
Trick: z-কে φ-independent noise-এর function হিসেবে লিখি:
তখন expectation ε-এর উপর, যার φ-এর সাথে কোনো relation নেই — gradient সরাসরি pass হয়!
Mean-Field Approximation
q(z)-এর family কী হবে? সবচেয়ে simple:
Each latent independent — "mean-field" assumption (physics থেকে আসা terminology)।
Coordinate ascent: একে একে প্রতিটি q_i update — Variational Bayes / CAVI।
Limitation: true posterior-এ correlation থাকলে capture করতে পারে না — structured variational methods লাগে।
Amortized Variational Inference
Traditional VI: প্রতিটি data point-এর জন্য আলাদা variational parameters — O(N × d)।
Amortized VI (VAE): একটি shared neural network q_φ(z|x) সব data point-এর জন্য:
Inference cost O(N × d) থেকে O(model params) — নতুন data point-এও instant inference!
Python: ELBO Computation
import torch
import torch.nn as nn
import torch.distributions as dist
def compute_elbo(x, encoder, decoder):
"""
encoder: x -> (mu, log_var)
decoder: z -> x_recon
"""
mu, log_var = encoder(x)
# Reparameterization trick
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
z = mu + std * eps
# Reconstruction
x_recon = decoder(z)
recon_loss = nn.functional.mse_loss(x_recon, x, reduction='sum')
# KL divergence: N(mu, sigma^2) || N(0, I)
kl_loss = -0.5 * torch.sum(1 + log_var - mu**2 - torch.exp(log_var))
# ELBO = -recon_loss - kl_loss (we minimize negative ELBO)
elbo = -recon_loss - kl_loss
return elbo, recon_loss, kl_loss
# Example dimensions
x = torch.randn(32, 784) # batch of MNIST images
encoder = nn.Sequential(
nn.Linear(784, 400), nn.ReLU(),
nn.Linear(400, 20 * 2) # mu and log_var for 20-dim latent
)
decoder = nn.Sequential(
nn.Linear(20, 400), nn.ReLU(),
nn.Linear(400, 784), nn.Sigmoid()
)
elbo, recon, kl = compute_elbo(x, encoder, decoder)
print(f"ELBO: {elbo.item():.2f}, Recon: {recon.item():.2f}, KL: {kl.item():.2f}")Modern Variational Methods
- Normalizing Flows — q(z)-কে invertible transformations দিয়ে complex distribution-এ transform।
- Diffusion Models — variational bound on log-likelihood, ELBO-র generalization।
- Neural Processes — VI for meta-learning, few-shot prediction।
- Stochastic Gradient VI — mini-batch দিয়ে large-scale VI, similar to SGD but for posterior approximation।
Practice Tasks
- ELBO-র derivation নিজে কাগজে করুন — log p(x) = ... থেকে শুরু করে।
- KL N(μ₁, σ₁²) || N(μ₂, σ₂²)-র closed form বের করুন।
- Reparameterization trick ছাড়া gradient কেন unbiased পাওয়া যায় না?
- Mean-field assumption কখন fail করে? উদাহরণ দিন।
Interview Questions
- VAE-তে ELBO কেন maximize করি?
- Reparameterization trick vs score function estimator (REINFORCE) — পার্থক্য?
- Amortized inference-এ "amortized" কী meaning?
- KL divergence কেন asymmetric? Forward vs reverse KL-এর implication কী?
Summary · সারসংক্ষেপ
- True posterior p(z|x) intractable — VI simpler q(z|x) দিয়ে approximate করে।
- ELBO = reconstruction − KL(q||prior) — maximize করলে q true posterior-এর কাছাকাছি আসে।
- Reparameterization trick = gradient through stochastic node pass করা।
- Amortized VI (VAE) = shared network সব data point-এর জন্য, scalable inference।
- Normalizing flows, diffusion models — VI-র modern extensions।