-
Notifications
You must be signed in to change notification settings - Fork 998
fix: fix tensor grouping bug & optimize MMD calculation in causal_prediction/algorithms/regularization.py #1371
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: JPZ4-5 <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR addresses critical bugs in the Regularizer class and introduces performance optimizations to the MMD calculation in causal prediction algorithms. The changes aim to fix a tensor grouping bug that caused incorrect penalty calculations and improve computational efficiency by ~40%.
Key Changes:
- Replaced manual hashing-based grouping logic with
torch.unique(dim=0, return_inverse=True)to fix dictionary key collision issues across environments - Introduced
_optimized_mmd_penaltymethod with algebraic optimization to reduce control flow overhead - Added
_compute_conditional_penaltyhelper to centralize conditional penalty computation logic - Enforced fp64 precision during MMD accumulation for numerical stability
- Added
use_optimizationparameter (defaulting to False) to allow gradual migration to optimized implementations
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
Thank you very much @JPZ4-5 for this contribution. The PR looks promising. @jivatneet could you take a look as well? (thank you!) |
Signed-off-by: JPZ4-5 <[email protected]>
Signed-off-by: JPZ4-5 <[email protected]>
|
Signed-off-by: JPZ4-5 <[email protected]>
|
It seems CI fail with System.IO.IOException: No space left on device. It looks like the runner ran out of disk space. Could you please trigger a re-run? |
This PR refactors the
Regularizerclass indowhy/causal_prediction/algorithms/regularization.pyto fix a critical logic error in cross-environment grouping and significantly improve computational efficiency. Given that this issue renders the current implementation mathematically incorrect and potentially harmful to model performance without raising errors, prompt review are highly recommended.Key Improvements
Replaced Grouping Logic:
factorsvector and dot product (grouping_data @ factors). Since GPU matrix multiplication (@) does not supportlongtypes, this required inefficient type casting betweenfloatandlong.torch.unique(dim=0, return_inverse=True)to handle grouping. This method is more robust, concise, and leverages native PyTorch optimizations without unnecessary type conversions.Bug Fix (Dictionary Key Issue):
tensor(1)from Env0 andtensor(1)from Env1) were treated as distinct objects. Consequently, incorrect MMD noise was added to the penalty because keys failed to collide across environments (as shown in the debug screenshot, identical keys from differentenvswere treated as different groups, leading to a wrong bigger penalty).torch.uniqueindices (or ensuring scalar keys are handled by value), ensuring data from different environments is correctly merged into the same pool.float64precision, casting back to the environment's default dtype (e.g.,float32) only after calculation.Empirical evidence and standard parameter search spaces suggest
gammais often very small (float32can lead to vanishing penalty terms or precision loss.float64ensures sufficient precision for the penalty accumulation.Benchmark: In local testing, this PR resulted in an approximate 40% speedup in training throughput (increasing from 2.5 it/s to 3.5 it/s). All 6 cases have tested.