-
Notifications
You must be signed in to change notification settings - Fork 147
Machine Learning Integration for DaCe (Autodiff - ONNX - PyTorch) #2164
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Changes from 250 commits
Commits
Show all changes
312 commits
Select commit
Hold shift + click to select a range
f03f622
Added SciPy to dependencies + formatting
affifboudaoud 008026c
Add pip dependency
and-ivanov aa4d836
Make daceml tests discoverable
and-ivanov 40e2338
fix frontend test
and-ivanov 0a9f71e
Fix paths
and-ivanov 988e230
Remove obsolete test that ensures only floating point computations ar…
and-ivanov f693f9b
Remove obsolete single state test
affifboudaoud b9812f5
Merge branch 'update_efforts' of https://github.com/affifboudaoud/dac…
affifboudaoud bc314d7
Remove deprecated np.bool
affifboudaoud 636d1ea
Fix API usage in test SDFGBackwardRunner
and-ivanov d1ade4a
Fix conv implementation for default strides and pads
and-ivanov 884d015
Fix ONNX operator expansions and their tests
and-ivanov a0dce78
Only use descriptor names to get AD data
affifboudaoud 946d1b9
Added simplify to test_nested to avoid FunctionCallRegions
affifboudaoud 25e3cc1
Added backward pass for Min reduction
affifboudaoud 9353d90
Remove init state transformation and test since we support multiple s…
affifboudaoud 431493b
Update bfs api
affifboudaoud baecfeb
Forward data if seprate_sdfgs
affifboudaoud f9e3ab6
Formatting + switch back to CPU
affifboudaoud 71a9761
Formatting + check for signature if seprate_sdfgs
affifboudaoud c2751ad
Fix for axes_arr of shape 1 + Formatting
affifboudaoud 1c6cf08
Multiple fixes to tests + Added main calls for debugging
affifboudaoud da7fe58
Make compatible with latest onnxruntime
and-ivanov ba9600c
Initialize unused arguments + Formatting
affifboudaoud 794cde0
Fix test_input_outputs.py
and-ivanov 9d49d6d
fix expansions
and-ivanov e3e8b03
fix test_bert.py: express softmax expansion in terms of simpler onnx …
and-ivanov 7b104f3
resolve onnx/onnxruntime versioning issues
and-ivanov 30e8394
Cleanup onnxruntime use
and-ivanov 0998b0b
fix test_shared_input_output.py
and-ivanov 59dc707
Fix expansions for Add,Sub,Mul,Div
and-ivanov 0e2e7d3
fix test_conv2d.py
and-ivanov afa8256
Changed test dtype to float64
affifboudaoud bebe5de
Fixed tensors_close print order
affifboudaoud e09cd82
Fix AccessSets analysis api call
affifboudaoud 6078c77
Formatting with yapf
affifboudaoud 0f5861a
Fix BackwardPass node creation and validation
affifboudaoud 3056c71
Remove FuncitonCallRegions before AD and initialize containers to zero
affifboudaoud 361746b
Fix batchnorm implementation
and-ivanov 70dc170
Express GlobalAveragePool through ReduceMean
and-ivanov 4c39da5
Iterate over state views not loop state views
affifboudaoud 0d5d613
Merge branch 'update_efforts' of https://github.com/affifboudaoud/dac…
affifboudaoud 5193e4f
Refactor ReduceMax,Min,Sum,Mean and fix ambiguity in passing scalars …
and-ivanov dd580c1
Add Llama Decoder inference test
affifboudaoud dd6a534
Add LlamaForCausalLM test
affifboudaoud 9e79460
Fix inf initialization + increase size limit for arrays
affifboudaoud 1a25959
Add new tensorproto format
affifboudaoud 4f3a5c1
Add new pure implementations + formatting
affifboudaoud 39f6899
Fix initialization for constant arrays that need to be forwarded to t…
affifboudaoud 21f1c12
Fix initialize_outputs_code call
affifboudaoud 2d85ec9
Added wcr sum to einsum backward output and fixed einsum expansion in…
affifboudaoud 9a75a94
Remove debug code
affifboudaoud 630771d
Add Llama decoder backward test
affifboudaoud 4edeaaa
Additional fixes to inf initializations
affifboudaoud 5a3e5a8
Add support for indirection
affifboudaoud 2f01229
Add initialization for integer tensors
affifboudaoud 0eff710
Remove constant inputs when constructing ONNX op replacements
affifboudaoud 511d9cd
Avoid gradient tracking for ONNX op attributes
affifboudaoud 2488b01
Fix ReduceSum backward implementation
affifboudaoud 4cbac69
Remove debug code
affifboudaoud 33e7e88
Enable ONNX simplify by default
affifboudaoud 447d6ba
Fix ReduceMax backward implementation
affifboudaoud 461a674
Add register storage ONNX codegen
affifboudaoud 741b0b1
Fix ReduceMean reduction conditions
affifboudaoud 1e8a1ca
Remove unnecessary wcr sum check
affifboudaoud 92357d2
Add pure BatchNorm implementation
affifboudaoud 573d753
Remove size limit for arrays
affifboudaoud c0ff565
Add ninja dependency and limit ONNX to 1.17
affifboudaoud 7f678ab
Add specific SDFG names to avoid folder mismatch with pytest
affifboudaoud ecced32
Avoid simplifying models for now
affifboudaoud a580723
Remove unused imports
affifboudaoud 5c068c8
Fix Einsum expansion to avoid duplicate descriptors
affifboudaoud 89846a3
Fix LayerNormalization backward implmenetation
affifboudaoud c69426a
Add full Llama backward test
affifboudaoud 452e446
Add pure ReduceSum implementation + Extend ReduceMean
affifboudaoud a71d04e
Fix LayerNormalization reduction axes
affifboudaoud 243b6d7
Remove obsolete tests and transformations
affifboudaoud 2200199
Update ORT C API and raw bindings
affifboudaoud 4215071
Set constant attributes for ONNX nodes
affifboudaoud 9d7dbe9
Improve tests by verifying all gradients + increase batch size
affifboudaoud f2143b8
Multiple fixes to reduction axes in pure expansions
affifboudaoud 965b327
Attempting to fix ORT C API
affifboudaoud aae9d13
Remove unnecessary views + obsolete GPU schedule code
affifboudaoud c977bce
Remove CPP implementations and improve softmax
affifboudaoud ac72044
Remove old Pow implementation
affifboudaoud 314402c
Fix forwarded value non-zero initialization
affifboudaoud 217fcfa
Remove debug SDFG save
affifboudaoud 86147c4
Add zero initializations
affifboudaoud 9f09de1
Add CopyToMap for GPU pass
affifboudaoud ea3c0b7
Merge remote-tracking branch 'origin/main' into dace_ad
affifboudaoud d2ad6b9
Merge lefover
affifboudaoud f04c534
Adapt to new API from merge
affifboudaoud cbdcc3d
Removed seprate dir for NPBench AD and added AD test prototype to k2mm
affifboudaoud 78fe0e5
Added AD NPBench tests
affifboudaoud bc52b50
Add expand operator and default value for steps in Slice
affifboudaoud 4b10931
Add all AD NPBench tests
affifboudaoud 5898e30
Fix gradient clearing
affifboudaoud 4137b5d
Formatting
affifboudaoud 052d82b
Remove obsolete tests
affifboudaoud 8fef9c6
Minor changes to tests + Formatting
affifboudaoud c02a43d
Formatting
affifboudaoud 0ee5b94
Avoid DDE in constant folding + Formatting
affifboudaoud 8045b02
Fix boolean tensor initialization
affifboudaoud 2ccb46d
Add Dropout forward impl + Formatting
affifboudaoud 40daa99
Fixes to BatchNorm + Formatting
affifboudaoud 8dd6764
Disable auto-opt by default
affifboudaoud 6bd42af
Add hooks before function init
affifboudaoud e2701ad
Formatting
affifboudaoud 84f46e0
Formatting
affifboudaoud 31ee8cb
Check for FunctionCallRegion in autodiff analysis
affifboudaoud a55aa67
Gradient clearing for single value arrays + Isolated node removal
affifboudaoud 7974ac2
Fix codegen for Indices subsets
affifboudaoud 1ebd677
Set transformers version to 4.5
affifboudaoud daa3a74
Remove GPU test for now
affifboudaoud 30078b5
Remove GPU tests for now
affifboudaoud 614fc0b
Remove unnecessary fixtures and remaining GPU tests
affifboudaoud 2b492bf
Restructure tests and add onnx marker
affifboudaoud 085d4d4
Update pytest marker
affifboudaoud 725e8d6
Remove AD auto-opt until transformed into passes
affifboudaoud b80d2b4
Remove ONNXRuntime dependency
affifboudaoud 3e87eb9
Remove AD auto-opt
affifboudaoud 7f4a561
[Restructuring] Moved functions to utils and removed experimental dyn…
affifboudaoud acf0def
Seprate SDFG element reversal from generator
affifboudaoud 2973c30
Separate more functions to utils and dace_nodes
affifboudaoud 703b1f3
[Restructuring] Moved storing and recomputation strategies into own dir
affifboudaoud faa472e
Fix typo
affifboudaoud 84dd7fc
Improve documentation
affifboudaoud afe1ecc
Remove unnecessary ONNXRuntime backend
affifboudaoud a8f411a
Remove onnx reporter
affifboudaoud e4ef648
Remove unnecessary testing dir
affifboudaoud 365ff71
Add design documents for each module
affifboudaoud 40c323d
Improve tests error messages and formatting
affifboudaoud f11ec23
Better documentation
affifboudaoud fcc14a5
Fix assertion in dlpack test
affifboudaoud f22dc61
Add comments
affifboudaoud 1b8546a
Make sure to compare to dace gradients when testing
affifboudaoud 595c755
Remove OpenBLAS dependency
affifboudaoud d65ff92
Add midding test packages
affifboudaoud e4eb08c
Merge remote-tracking branch 'origin/main' into dace_ad
affifboudaoud 8975401
Add missing package + Formatting
affifboudaoud 6168e9f
Allow Python 3.13 and ONNX 1.18
affifboudaoud 1fd631f
Set onnx IR version explicitly
affifboudaoud 073292a
Pre-commit formatting
affifboudaoud fdc0e3f
Serialization fixes
affifboudaoud 4d0792b
Fix paths for cpp extensions
affifboudaoud 877833a
Unique auto_opt name and expansion edge case
affifboudaoud 4a69aa9
Skip some AD tests until serialization issue is fixed
affifboudaoud 2b5d29a
Revert to main code
affifboudaoud 57168e0
Remove conda specific import
affifboudaoud 37030b1
Use expanded sdfgs instead of function call
affifboudaoud f096ad0
Make Torch and ONNX dependencies optional
affifboudaoud 747e1e5
Update CI installation
affifboudaoud 1bd2154
Update all CI installations
affifboudaoud 8b73bf5
Avoid conflicting names got batch size in MKL implementation
affifboudaoud fafc460
Build Torch module in unique dir to avoid baton issues
affifboudaoud 2aa4e3a
Attempting to reduce CI runtime with smaller sizes
affifboudaoud d2d31b9
Simplify durbin test
affifboudaoud f4b2c7c
Simplify resent
affifboudaoud 1dc4d8b
Even smaller sizes for cavity_flow
affifboudaoud ad34aba
Avoid data race in loop lifiting test
affifboudaoud f07a6e6
Formatting
affifboudaoud 477738b
Set JAX version to avoid conflict with cupy
affifboudaoud dec4703
set JAX to <= 0.6.2
affifboudaoud 341d3b3
Smaller inputs for Cholesky
affifboudaoud 21a2538
Make ReplacementTransformation abstract to pass coverage tests
affifboudaoud c998f1e
Remove redundant ReverseReduceMax class
affifboudaoud adb8873
Merge remote-tracking branch 'origin/main' into dace_ad
affifboudaoud 7db9d1c
Remove redundant reduction code
affifboudaoud af5b3aa
Remove duplicate code and TODOs
affifboudaoud 613f940
Add copyright headers
affifboudaoud 141dee7
Revert SDFG validation changes
affifboudaoud 8dcb470
Enable make_transients_persistent in auto_opt
affifboudaoud 1195dfc
Revert DDE changes
affifboudaoud 5c1c342
Remove comment
affifboudaoud a958728
Merge remote-tracking branch 'origin/main' into dace_ad
affifboudaoud 8dc611e
Remove unused global variable
affifboudaoud eb6b06e
Allow MKL in ONNX auto opt
affifboudaoud fa012d7
Improve documentation
affifboudaoud be1de1a
Improve error message
affifboudaoud c5f7b23
Merge remote-tracking branch 'origin/dace_ad' into dace_ad
affifboudaoud 5c5a3ef
Add non-pytest test calls in main
affifboudaoud dd247c8
Categorized ONNX pure implementations
affifboudaoud 838ac7b
Fix race condition view storing and add imports
affifboudaoud 0515a64
Add autodiff test call to main for debugging
affifboudaoud df0091b
Replace ONNXRuntime shape inference with ONNX shape inference
affifboudaoud 6598235
Merge branch 'main' into dace_ad
affifboudaoud 52439bb
Add onnxscript dependency
affifboudaoud cfc5cbc
Merge remote-tracking branch 'origin/dace_ad' into dace_ad
affifboudaoud b592bae
Use ONNXRuntime shape inference from package and keep ONNX as fallback
affifboudaoud cd99538
Disable Dynamo which is default in torch > 2.9
affifboudaoud 9515d57
Remove ONNXruntime tools fallback
affifboudaoud b56a90a
Check the shape inference output for incomplete processing
affifboudaoud 48a4005
Update dace/autodiff/autodiff.md
affifboudaoud a4376b5
Update dace/autodiff/autodiff.md
affifboudaoud a65e45a
Update dace/autodiff/backward_pass_generator.py
affifboudaoud 3206f16
Fix typo
affifboudaoud 93b7596
Merge branch 'dace_ad' of https://github.com/spcl/dace into dace_ad
affifboudaoud 2013cb8
Fix stringdoc
affifboudaoud 7ac0eb1
Add loss function to the autodiff example
affifboudaoud d5a1680
Use SDFG API to generate new names and symbols
affifboudaoud d558984
Fix static dtype in backward implementations
affifboudaoud da03c60
Use dace pi instead of hardcoded constant
affifboudaoud bacf827
Use sympy instead of regex matching
affifboudaoud c03be81
Use mapped tasklet API instead of building store maps manually
affifboudaoud be79d6e
Use symbolic affine expression matching for loop iterators and fix fi…
affifboudaoud 350b3ae
Add copyright header for test files
affifboudaoud 567cd7c
Remove dead code function and fix isinstance checks
affifboudaoud d7e981c
Fix is_previously_written analysis and decrease lu test size
affifboudaoud 6449e2b
Fix np.has_path use
affifboudaoud ca39caa
[Restructuring] Move ML frontends to dace.frontend.ml and enable impo…
affifboudaoud 9ba1608
Added lazy import for torch components and enabled new decorator @dac…
affifboudaoud 9006e42
Remove unnecessary data files and add .onnx .bin files to .gitignore
affifboudaoud 35b8a71
Remove single state requirement from SDFGBackwardRunner
affifboudaoud d624623
Remove unconditonals (1) from backward InterStateEdges
affifboudaoud 60e14b7
Add multi-state autodiff unit tests
affifboudaoud 5d994a9
Disable serialization for a few tests instead of completely skipping …
affifboudaoud 3a5ec29
Register ONNX expansion class for serialization
affifboudaoud 0e4f25a
Run ML tests as a seprate Action in CI
affifboudaoud 9225664
Add xdist_group for large models to avoid crashing in CI
affifboudaoud 26cffb5
Fix xdist marker
affifboudaoud a3a8856
Run CI ML tests sequentially and only with Python 3.13
affifboudaoud f074009
Merge remote-tracking branch 'origin/main' into dace_ad
affifboudaoud 30b5b04
Remove unnecessary install of ml deps in hardware CI
affifboudaoud 7a8bf79
Remove unnecessary install of ml deps in heterogeneous CI
affifboudaoud c1124de
Merge branch 'dace_ad' of https://github.com/spcl/dace into dace_ad
affifboudaoud d975c87
Update dace/autodiff/data_forwarding/recompute.py
affifboudaoud 8ead6ba
Remove ONNX op files and build models at runtime
affifboudaoud 4b0e88e
Fixes to recomputation + Add dill dependency back
affifboudaoud 45bdcf8
Separate testing and ml-testing extra_requires
affifboudaoud 388131c
Merge remote-tracking branch 'origin/main' into dace_ad
affifboudaoud 1e913c2
Add importorskip for ml only deps + remove redundant forward tests
affifboudaoud 8b031b2
Add tests/data to gitignore
affifboudaoud 535b54b
Improve doc strings and remove unnecessary installs from ml-ci
affifboudaoud 385d2a0
Improve docstrings + avoid asserting function calls + better Symbol h…
affifboudaoud a7968e7
Multiple CI adaptions + set networkx to 3.5 for now
affifboudaoud 9903346
Remove reduction code duplication + hardcoded datatypes
affifboudaoud 743f4af
rename torch dir to avoid pytest conflict
affifboudaoud 6f5b497
Replace find_str_not_in_set + Simplify ParameterArray creation
affifboudaoud 90c77e2
Removed dace/util
affifboudaoud a0b2bc0
Use debugprint + Simpler python_pure_op implementations
affifboudaoud 988ae5b
Remove TF imports + Add create_child_generator
affifboudaoud e3f544d
Update dace/autodiff/data_forwarding/store.py
affifboudaoud 516e8bb
Update dace/autodiff/data_forwarding/store.py
affifboudaoud 63c5292
Typo fixes. Co-Authored-By: Tal Ben-Nun
affifboudaoud d125416
Fix layernorm type issue
affifboudaoud 1aa572a
Fix extremal reduction backward implementation and add unit tests
affifboudaoud 6f2e7b7
Remove sdfg_name conftest fixture and add unique name to tests
affifboudaoud b31646d
Fix onnx proto conversion in onnx-1.20
affifboudaoud 141de99
Increase reduction tolerance
affifboudaoud 1196889
Update MD files
affifboudaoud File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| name: Machine Learning and Autodiff Tests | ||
|
|
||
| on: | ||
| push: | ||
| branches: [ main, ci-fix ] | ||
| pull_request: | ||
| branches: [ main, ci-fix ] | ||
| merge_group: | ||
| branches: [ main, ci-fix ] | ||
|
|
||
| concurrency: | ||
| group: ${{github.workflow}}-${{github.ref}} | ||
| cancel-in-progress: true | ||
|
|
||
| jobs: | ||
| test: | ||
| if: "!contains(github.event.pull_request.labels.*.name, 'no-ci')" | ||
| runs-on: ubuntu-latest | ||
| strategy: | ||
| matrix: | ||
| python-version: ['3.13'] | ||
| simplify: [0,1,autoopt] | ||
|
|
||
| steps: | ||
| - uses: actions/checkout@v4 | ||
| with: | ||
| submodules: 'recursive' | ||
| - name: Set up Python ${{ matrix.python-version }} | ||
| uses: actions/setup-python@v5 | ||
| with: | ||
| python-version: ${{ matrix.python-version }} | ||
| - name: Install dependencies | ||
| run: | | ||
| sudo apt-get update | ||
| sudo apt-get install -y libyaml-dev cmake | ||
| sudo apt-get install -y libblas-dev libopenblas-dev liblapacke-dev | ||
| python -m pip install --upgrade pip | ||
| pip install flake8 pytest-xdist coverage | ||
| pip install -e ".[ml-testing,ml]" | ||
| curl -Os https://uploader.codecov.io/latest/linux/codecov | ||
| chmod +x codecov | ||
|
|
||
| - name: Test with pytest | ||
| run: | | ||
| export NOSTATUSBAR=1 | ||
| export DACE_testing_serialization=1 | ||
| export DACE_testing_deserialize_exception=1 | ||
| export DACE_cache=unique | ||
| if [ "${{ matrix.simplify }}" = "autoopt" ]; then | ||
| export DACE_optimizer_automatic_simplification=1 | ||
| export DACE_optimizer_autooptimize=1 | ||
| echo "Auto-optimization heuristics" | ||
| else | ||
| export DACE_optimizer_automatic_simplification=${{ matrix.simplify }} | ||
| fi | ||
| pytest --cov-report=xml --cov=dace --tb=short --timeout_method thread --timeout=300 -v -m "(torch or onnx or autodiff) and not gpu" | ||
| ./codecov | ||
|
|
||
| - uses: codecov/codecov-action@v4 | ||
| with: | ||
| token: ${{ secrets.CODECOV_TOKEN }} | ||
| verbose: true |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -195,3 +195,8 @@ _build/ | |
|
|
||
| # Ignoring the test junk | ||
| _all_tests/ | ||
|
|
||
|
|
||
| # Ignore downloaded ONNX models | ||
| /*.onnx | ||
| /*.bin | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| # Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. | ||
| """ | ||
| DaCe Automatic Differentiation (AD) System. | ||
|
|
||
| This module provides reverse-mode automatic differentiation for DaCe programs, | ||
| enabling automatic computation of gradients for optimized numerical kernels. | ||
|
|
||
| Main Components | ||
| --------------- | ||
| - **add_backward_pass**: Main entry point for adding backward pass to an SDFG | ||
| - **BackwardPassGenerator**: Core algorithm for generating backward passes | ||
| - **BackwardImplementation**: ABC for implementing operation-specific backward rules | ||
| - **BackwardContext**: Context information for backward pass generation | ||
| - **BackwardResult**: Result of backward pass generation with forward/backward SDFGs | ||
| - **AutoDiffException**: Base exception for autodiff errors | ||
|
|
||
| Key Features | ||
| ------------ | ||
| - Support for control flow (loops, conditionals) | ||
| - Data forwarding strategies (store vs recompute tradeoffs) | ||
| - Extensible backward implementations for library nodes | ||
| - Integration with PyTorch autograd | ||
| - Automatic memory management for intermediate values | ||
|
|
||
|
|
||
| """ | ||
|
|
||
| from .base_abc import BackwardImplementation, BackwardContext, BackwardResult, AutoDiffException | ||
| from .backward_pass_generator import BackwardPassGenerator | ||
| from .autodiff import add_backward_pass | ||
|
|
||
| try: | ||
| from .torch import make_backward_function | ||
| TORCH_INTEGRATION_AVAILABLE = True | ||
| except ImportError: | ||
| make_backward_function = None | ||
| TORCH_INTEGRATION_AVAILABLE = False | ||
|
|
||
| import sys | ||
| from . import library | ||
|
|
||
| __all__ = [ | ||
| # Main API | ||
| "add_backward_pass", | ||
| # Core classes | ||
| "BackwardPassGenerator", | ||
| "BackwardContext", | ||
| "BackwardResult", | ||
| # Extension points | ||
| "BackwardImplementation", | ||
| # Exceptions | ||
| "AutoDiffException", | ||
| # Submodules | ||
| "library", | ||
| ] | ||
|
|
||
| if TORCH_INTEGRATION_AVAILABLE: | ||
| __all__.append("make_backward_function") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,103 @@ | ||
| # Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. | ||
| """ | ||
affifboudaoud marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| Analysis helpers for autodiff | ||
| """ | ||
| from typing import Dict, Set, Tuple, Optional | ||
| import collections | ||
|
|
||
| import networkx as nx | ||
|
|
||
| from dace.sdfg import SDFG, SDFGState, nodes, utils as sdfg_utils | ||
| from dace.transformation.passes import analysis | ||
| from dace.sdfg.state import FunctionCallRegion | ||
|
|
||
| AccessSets = Dict[SDFGState, Tuple[Set[str], Set[str]]] | ||
|
|
||
|
|
||
| def dependency_analysis(sdfg: SDFG) -> Dict[str, Set[str]]: | ||
| """ | ||
| Analyze read dependencies of arrays in the SDFG. | ||
|
|
||
| :param sdfg: SDFG to analyze | ||
| :return: A dictionary mapping array names to a list of read dependencies. | ||
| """ | ||
|
|
||
| # FIXME can be made more efficient | ||
| dependencies = nx.DiGraph() | ||
| for sdfg_node in sdfg.nodes(): | ||
| if isinstance(sdfg_node, SDFGState): | ||
| for node in sdfg_node.data_nodes(): | ||
| for edge in sdfg_node.edge_bfs(node, reverse=True): | ||
| dependencies.add_edge(node.data, edge.data.data) | ||
| elif isinstance(sdfg_node, FunctionCallRegion): | ||
alexnick83 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| for state in sdfg_node.nodes(): | ||
| assert isinstance(state, SDFGState) | ||
| for node in state.data_nodes(): | ||
| for edge in state.edge_bfs(node, reverse=True): | ||
| dependencies.add_edge(node.data, edge.data.data) | ||
|
|
||
| dependencies = nx.transitive_closure(dependencies) | ||
| result = {} | ||
| for array in dependencies: | ||
| result[array] = {nbr for nbr in dependencies.neighbors(array)} | ||
| return result | ||
|
|
||
|
|
||
| def inverse_reachability(sdfg: SDFG) -> Dict[SDFGState, Set[SDFGState]]: | ||
|
|
||
| reachability = analysis.StateReachability().apply_pass(sdfg, {}) | ||
| inverse_reachability = collections.defaultdict(set) | ||
| # iterate over cfg_ids | ||
| for cfg_id in reachability.keys(): | ||
| for pred, successors in reachability[cfg_id].items(): | ||
| for successor in successors: | ||
| inverse_reachability[successor].add(pred) | ||
|
|
||
| return inverse_reachability | ||
|
|
||
|
|
||
| def is_previously_written(sdfg: SDFG, | ||
| state: SDFGState, | ||
| node: nodes.Node, | ||
| array_name: str, | ||
| access_sets: Optional[AccessSets] = None) -> bool: | ||
| """ | ||
| Determine whether the given array name was written before the current node. | ||
|
|
||
| :param sdfg: the sdfg containing the node | ||
| :param state: the state containing the node | ||
| :param node: the node to check | ||
| :param array_name: the array name to check | ||
| :return: True if the array was written before the node, False otherwise. | ||
| """ | ||
|
|
||
| if access_sets is None: | ||
| access_sets = analysis.AccessSets().apply_pass(sdfg, {}) | ||
|
|
||
| reachable = inverse_reachability(sdfg) | ||
|
|
||
| # Check the current state | ||
| for subgraph in sdfg_utils.concurrent_subgraphs(state): | ||
| if node in subgraph.nodes(): | ||
| # Get all the access nodes in the subgraph to the same data | ||
| for other_node in subgraph.data_nodes(): | ||
| if other_node != node and other_node.data == array_name: | ||
| # Check if this is a write node | ||
| for in_edge in subgraph.in_edges(other_node): | ||
| if in_edge.data.data == array_name: | ||
| # Check if there's a path to our node, | ||
| # since we only care about writes that happen before the current node | ||
| if nx.has_path(state.nx, other_node, node): | ||
| return True | ||
| else: | ||
| # This is not our current subgraph, check the write states | ||
| _, write_set = subgraph.read_and_write_sets() | ||
| if array_name in write_set: | ||
| return True | ||
|
|
||
| # Check other states | ||
| for predecessor in reachable[state]: | ||
| _, write_set = access_sets[predecessor] | ||
| if array_name in write_set: | ||
| return True | ||
| return False | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a very nonstandard piece of code. Do we need it? If nobody imports
dace.mlthen the time will not be spent AFAIU.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was added so that we can do:
Without always importing the ml modules to save on import time. We can remove this code and note that a user needs to import dace.ml to be able to use the decorator.
@
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that requiring to import
dace.mlis very reasonable and what most python packages do.