1
+ #!/usr/bin/env python3
2
+
3
+ # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ import json
7
+ import yaml
8
+ import sys
9
+ import os
10
+ from typing import Dict , List , Any
11
+
12
+ def write_output (key : str , value : str ):
13
+ """Write GitHub Actions output."""
14
+ github_output = os .environ .get ('GITHUB_OUTPUT' , '/dev/null' )
15
+ with open (github_output , 'a' ) as f :
16
+ f .write (f"{ key } ={ value } \n " )
17
+ print (f"{ key } ={ value } " )
18
+
19
+ def write_json_file (filename : str , data : Any ):
20
+ """Write data to a JSON file."""
21
+ with open (filename , 'w' ) as f :
22
+ json .dump (data , f , indent = 2 )
23
+
24
+ def write_text_file (filename : str , content : str ):
25
+ """Write content to a text file."""
26
+ with open (filename , 'w' ) as f :
27
+ f .write (content )
28
+
29
+ def explode_std_versions (matrix_entries : List [Dict [str , Any ]]) -> List [Dict [str , Any ]]:
30
+ """Explode std arrays into individual entries."""
31
+ result = []
32
+ for entry in matrix_entries :
33
+ if 'std' in entry and isinstance (entry ['std' ], list ):
34
+ for std in entry ['std' ]:
35
+ new_entry = entry .copy ()
36
+ new_entry ['std' ] = std
37
+ result .append (new_entry )
38
+ else :
39
+ result .append (entry )
40
+ return result
41
+
42
+ def extract_matrix (file_path : str , matrix_type : str ):
43
+ """Extract and process the matrix configuration."""
44
+ try :
45
+ with open (file_path , 'r' ) as f :
46
+ data = yaml .safe_load (f )
47
+
48
+ if matrix_type not in data :
49
+ print (f"Error: Matrix type '{ matrix_type } ' not found in { file_path } " , file = sys .stderr )
50
+ sys .exit (1 )
51
+
52
+ matrix = data [matrix_type ]
53
+
54
+ # Write devcontainer version
55
+ devcontainer_version = data .get ('devcontainer_version' , '25.08' )
56
+ write_output ("DEVCONTAINER_VERSION" , devcontainer_version )
57
+
58
+ # Process nvcc matrix
59
+ if 'nvcc' not in matrix :
60
+ print (f"Error: 'nvcc' section not found in { matrix_type } matrix" , file = sys .stderr )
61
+ sys .exit (1 )
62
+
63
+ nvcc_matrix = matrix ['nvcc' ]
64
+ nvcc_full_matrix = explode_std_versions (nvcc_matrix )
65
+
66
+ write_output ("NVCC_FULL_MATRIX" , json .dumps (nvcc_full_matrix ))
67
+
68
+ # Extract unique CUDA versions
69
+ cuda_versions = list (set (entry ['cuda' ] for entry in nvcc_full_matrix ))
70
+ cuda_versions .sort ()
71
+ write_output ("CUDA_VERSIONS" , json .dumps (cuda_versions ))
72
+
73
+ # Extract unique host compilers
74
+ host_compilers = list (set (entry ['compiler' ]['name' ] for entry in nvcc_full_matrix ))
75
+ host_compilers .sort ()
76
+ write_output ("HOST_COMPILERS" , json .dumps (host_compilers ))
77
+
78
+ # Create per-cuda-compiler matrix
79
+ per_cuda_compiler = {}
80
+ for entry in nvcc_full_matrix :
81
+ key = f"{ entry ['cuda' ]} -{ entry ['compiler' ]['name' ]} "
82
+ if key not in per_cuda_compiler :
83
+ per_cuda_compiler [key ] = []
84
+ per_cuda_compiler [key ].append (entry )
85
+
86
+ write_output ("PER_CUDA_COMPILER_MATRIX" , json .dumps (per_cuda_compiler ))
87
+
88
+ # Create output directory and write detailed files (CCCL approach)
89
+ os .makedirs ("workflow" , exist_ok = True )
90
+
91
+ # Write individual output files for debugging and artifacts
92
+ write_json_file ("workflow/devcontainer_version.json" , {"version" : devcontainer_version })
93
+ write_json_file ("workflow/nvcc_full_matrix.json" , nvcc_full_matrix )
94
+ write_json_file ("workflow/cuda_versions.json" , cuda_versions )
95
+ write_json_file ("workflow/host_compilers.json" , host_compilers )
96
+ write_json_file ("workflow/per_cuda_compiler_matrix.json" , per_cuda_compiler )
97
+
98
+ # Write summary
99
+ summary = {
100
+ "total_matrix_entries" : len (nvcc_full_matrix ),
101
+ "cuda_compiler_combinations" : len (per_cuda_compiler ),
102
+ "cuda_versions" : cuda_versions ,
103
+ "host_compilers" : host_compilers
104
+ }
105
+ write_json_file ("workflow/matrix_summary.json" , summary )
106
+
107
+ # Write human-readable summary
108
+ summary_text = f"Matrix Summary:\n "
109
+ summary_text += f" Total matrix entries: { len (nvcc_full_matrix )} \n "
110
+ summary_text += f" CUDA versions: { ', ' .join (cuda_versions )} \n "
111
+ summary_text += f" Host compilers: { ', ' .join (host_compilers )} \n "
112
+ summary_text += f" CUDA-compiler combinations: { len (per_cuda_compiler )} \n \n "
113
+ summary_text += "Combinations:\n "
114
+ for key , entries in per_cuda_compiler .items ():
115
+ summary_text += f" { key } : { len (entries )} entries\n "
116
+
117
+ write_text_file ("workflow/matrix_summary.txt" , summary_text )
118
+
119
+ print (f"Successfully processed { len (nvcc_full_matrix )} matrix entries" , file = sys .stderr )
120
+ print (f"Generated { len (per_cuda_compiler )} cuda-compiler combinations" , file = sys .stderr )
121
+ print ("Matrix data written to workflow/ directory" , file = sys .stderr )
122
+
123
+ except FileNotFoundError :
124
+ print (f"Error: Matrix file '{ file_path } ' not found" , file = sys .stderr )
125
+ sys .exit (1 )
126
+ except yaml .YAMLError as e :
127
+ print (f"Error parsing YAML file '{ file_path } ': { e } " , file = sys .stderr )
128
+ sys .exit (1 )
129
+ except KeyError as e :
130
+ print (f"Error: Missing required key in matrix file: { e } " , file = sys .stderr )
131
+ sys .exit (1 )
132
+ except Exception as e :
133
+ print (f"Unexpected error processing matrix: { e } " , file = sys .stderr )
134
+ sys .exit (1 )
135
+
136
+ def main ():
137
+ if len (sys .argv ) != 3 :
138
+ print ("Usage: compute-matrix.py MATRIX_FILE MATRIX_TYPE" , file = sys .stderr )
139
+ print (" MATRIX_FILE : The path to the matrix file." , file = sys .stderr )
140
+ print (" MATRIX_TYPE : The desired matrix. Supported values: 'pull_request'" , file = sys .stderr )
141
+ sys .exit (1 )
142
+
143
+ matrix_file = sys .argv [1 ]
144
+ matrix_type = sys .argv [2 ]
145
+
146
+ if matrix_type != "pull_request" :
147
+ print (f"Error: Unsupported matrix type '{ matrix_type } '. Only 'pull_request' is supported." , file = sys .stderr )
148
+ sys .exit (1 )
149
+
150
+ print (f"Input matrix file: { matrix_file } " , file = sys .stderr )
151
+ print (f"Matrix Type: { matrix_type } " , file = sys .stderr )
152
+
153
+ # Show matrix file content for debugging
154
+ try :
155
+ with open (matrix_file , 'r' ) as f :
156
+ content = f .read ()
157
+ print ("Matrix file content:" , file = sys .stderr )
158
+ print (content , file = sys .stderr )
159
+ print ("=" * 50 , file = sys .stderr )
160
+ except Exception as e :
161
+ print (f"Warning: Could not read matrix file for debugging: { e } " , file = sys .stderr )
162
+
163
+ extract_matrix (matrix_file , matrix_type )
164
+
165
+ if __name__ == "__main__" :
166
+ main ()
0 commit comments