
Conversation

@JPZ4-5
Contributor

@JPZ4-5 JPZ4-5 commented Nov 27, 2025

This PR refactors the Regularizer class in dowhy/causal_prediction/algorithms/regularization.py to fix a critical logic error in cross-environment grouping and to significantly improve computational efficiency. Since this issue renders the current implementation mathematically incorrect and potentially harmful to model performance without raising any errors, prompt review is highly recommended.

Key Improvements

  1. Replaced Grouping Logic:

    • Legacy: Relied on a manual hashing approach using a factors vector and dot product (grouping_data @ factors). Since GPU matrix multiplication (@) does not support long types, this required inefficient type casting between float and long.
    • New: Adopted torch.unique(dim=0, return_inverse=True) to handle grouping. This method is more robust, concise, and leverages native PyTorch optimizations without unnecessary type conversions.
  2. Bug Fix (Dictionary Key Issue):

    • Issue: The legacy implementation used PyTorch tensors as keys for Python dictionaries. Tensors hash by object identity, so in cross-environment settings identical scalar tensors from different environments (e.g., tensor(1) from Env0 and tensor(1) from Env1) were treated as distinct keys. Because keys failed to collide across environments, spurious MMD terms were added to the penalty (as shown in the debug screenshot below, identical keys from different environments were treated as separate groups, incorrectly inflating the penalty).
    • Fix: The new implementation naturally resolves this by utilizing torch.unique indices (or ensuring scalar keys are handled by value), ensuring data from different environments is correctly merged into the same pool (see the sketch after this list).
[debug screenshot: identical keys from different environments shown as separate groups]
  3. Algebraic Optimization & Throughput:
    • Refactored the MMD penalty calculation to use an algebraically optimized form instead of nested Python loops, which significantly reduces control-flow overhead and improves GPU throughput (a worked sketch follows the benchmark note below).
    • Formula:

$$ \sum_{i=1}^n\sum_{j=i+1}^n (K_{ii}+K_{jj}-2K_{ij})=(n-1)\sum_{i=1}^n K_{ii}-2\sum_{i=1}^n\sum_{j=i+1}^n K_{ij} $$

  4. Numerical Stability (Enforced fp64):
    • Change: Forced MMD accumulation to use float64 precision, casting back to the environment's default dtype (e.g., float32) only after the calculation.
    • Rationale: Empirical evidence and standard parameter search spaces suggest gamma is often very small ($10^{-5}$ to $10^{-7}$). Computing Gaussian kernels with such small values in float32 can lead to vanishing penalty terms or precision loss; float64 ensures sufficient precision for the penalty accumulation.
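
A minimal, self-contained sketch of points 1 and 2 (illustrative names only, not the dowhy code) showing why tensor dictionary keys fail to collide and how torch.unique(dim=0, return_inverse=True) groups rows by value instead:

```python
import torch

# Two environments produce the "same" scalar label as two separate tensor objects.
key_env0 = torch.tensor(1)
key_env1 = torch.tensor(1)

groups = {key_env0: ["env0 data"]}
# torch.Tensor hashes by object identity, so this does NOT collide with key_env0:
groups.setdefault(key_env1, []).append("env1 data")
print(len(groups))  # 2 -- data that should be pooled stays split

# Grouping rows of an (N, k) label matrix by value with torch.unique instead:
grouping_data = torch.tensor([[1, 0], [0, 1], [1, 0], [0, 1]])
unique_rows, inverse = torch.unique(grouping_data, dim=0, return_inverse=True)
# inverse[i] is the group index of row i; identical rows map to the same integer
# regardless of which environment they came from, so pooling is a mask per group.
for g in range(unique_rows.shape[0]):
    members = (inverse == g).nonzero(as_tuple=True)[0]
    print(f"group {g}: rows {members.tolist()}")
```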

Benchmark: In local testing, this PR resulted in an approximate 40% speedup in training throughput (increasing from 2.5 it/s to 3.5 it/s). All 6 cases have been tested.
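
As a rough illustration of points 3 and 4 (a hedged sketch, not the actual dowhy implementation; gaussian_gram, gamma, and the shapes are assumptions), the algebraic identity above can be checked against the naive double loop, with accumulation done in float64:

```python
import torch

def gaussian_gram(x, gamma):
    # K[i, j] = exp(-gamma * ||x_i - x_j||^2)
    sq_dists = torch.cdist(x, x, p=2) ** 2
    return torch.exp(-gamma * sq_dists)

def penalty_naive(K):
    # Direct translation of sum_{i<j} (K_ii + K_jj - 2*K_ij), accumulated in fp64.
    n = K.shape[0]
    total = torch.zeros((), dtype=torch.float64)
    for i in range(n):
        for j in range(i + 1, n):
            total += (K[i, i] + K[j, j] - 2 * K[i, j]).double()
    return total

def penalty_vectorized(K):
    # (n-1) * sum_i K_ii - 2 * sum_{i<j} K_ij, using the symmetry of K.
    K = K.double()
    n = K.shape[0]
    diag_sum = torch.diagonal(K).sum()
    upper_sum = (K.sum() - diag_sum) / 2
    return (n - 1) * diag_sum - 2 * upper_sum

x = torch.randn(16, 8)
K = gaussian_gram(x, gamma=1e-5)
print(torch.allclose(penalty_naive(K), penalty_vectorized(K)))  # True
```

Both paths agree numerically; the vectorized form simply avoids the Python-level double loop, which is where the throughput gain comes from.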


Copilot AI left a comment


Pull request overview

This PR addresses critical bugs in the Regularizer class and introduces performance optimizations to the MMD calculation in causal prediction algorithms. The changes aim to fix a tensor grouping bug that caused incorrect penalty calculations and improve computational efficiency by ~40%.

Key Changes:

  • Replaced manual hashing-based grouping logic with torch.unique(dim=0, return_inverse=True) to fix dictionary key collision issues across environments
  • Introduced _optimized_mmd_penalty method with algebraic optimization to reduce control flow overhead
  • Added _compute_conditional_penalty helper to centralize conditional penalty computation logic
  • Enforced fp64 precision during MMD accumulation for numerical stability
  • Added use_optimization parameter (defaulting to False) to allow gradual migration to optimized implementations (see the dispatch sketch below)
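
As a rough sketch of how the opt-in flag might be wired (method names other than _optimized_mmd_penalty are hypothetical; the real Regularizer signature in dowhy may differ):

```python
import torch

class RegularizerSketch:
    """Skeleton only: shows the opt-in dispatch, not dowhy's actual class."""

    def __init__(self, use_optimization: bool = False):
        # Defaults to False, so existing behavior is preserved unless opted in.
        self.use_optimization = use_optimization

    def mmd(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        if self.use_optimization:
            return self._optimized_mmd_penalty(x, y)  # vectorized, fp64 accumulation
        return self._loop_mmd_penalty(x, y)           # readable nested-loop path

    def _optimized_mmd_penalty(self, x, y):
        ...  # algebraically simplified form (see the formula in the PR description)

    def _loop_mmd_penalty(self, x, y):
        ...  # straightforward loops; easier to adapt to custom kernels
```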


@emrekiciman
Member

Thank you very much @JPZ4-5 for this contribution. The PR looks promising. @jivatneet could you take a look as well? (thank you!)

@JPZ4-5
Contributor Author

JPZ4-5 commented Nov 29, 2025

  1. Regarding E_eq_A=True Logic:
    The review suggested replacing torch.full(..., i) with attribute_labels[i]. However, according to the original CACM paper, when E_eq_A=True the algorithm is explicitly designed to use the environment index as the sensitive attribute, regardless of what the raw attribute_labels contain.
    Constructing the labels manually from the environment index i is therefore the intended behavior (see the sketch after this list).

  2. Code Fixes Applied:

    • Corrected features.dtype access.
    • Fixed the initialization of the covariance matrix in the else branch (using torch.zeros instead of .diag() to ensure correct shape for $N=1$).
    • Completed the missing docstrings.
    • Standardized the usage of torch.tensor vs tensor.
  3. Clarification on MMD Calculation & use_optimization:
    To be explicit (also for the reviewing bot): the critical bug was solely in the grouping stage (using tensors as dictionary keys), and this PR fixes it.
    I retained the unoptimized path because it offers higher readability and is easier to extend with future custom kernels. Not every kernel has a straightforward vectorized implementation for pooled data, and developers may prioritize readability and development efficiency over trivial algebraic optimizations for complex kernels (as in my case).
    The use_optimization flag lets developers opt in when they are using the standard gaussian_kernel (or other kernels with clear efficiency gains from vectorization). This parameter can easily be toggled in the CACM class if needed.
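
For reference, a small sketch of the E_eq_A=True behavior defended in point 1 (assumed shapes and a hypothetical helper name, not the actual dowhy code), where every sample in environment i receives the environment index itself as its attribute label:

```python
import torch

def env_as_attribute_labels(features_per_env):
    # features_per_env: list of (n_i, d) feature tensors, one per environment
    labels = []
    for i, feats in enumerate(features_per_env):
        # Every sample in environment i gets attribute value i, ignoring
        # whatever the raw attribute_labels contain.
        labels.append(torch.full((feats.shape[0],), i, dtype=torch.long))
    return labels

envs = [torch.randn(4, 3), torch.randn(2, 3)]
print([l.tolist() for l in env_as_attribute_labels(envs)])
# [[0, 0, 0, 0], [1, 1]]
```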

Signed-off-by: JPZ4-5 <[email protected]>
@JPZ4-5
Contributor Author

JPZ4-5 commented Dec 5, 2025

It seems CI failed with System.IO.IOException: No space left on device; it looks like the runner ran out of disk space. Could you please trigger a re-run?

