Description
Many models need to select entries from an array using symbolic indices. Think resource assignment, scheduling, routing—the same pattern.
I emulate those gathers by looping over every possible index and using where + sum to build a one-hot mask. For an array of length H, that costs on the order of 3×H nodes per lookup.
A built-in take (scalar and batched) would let us write value = take(arr, idx), just as with numpy.take. The engine could then translate it into a single node internally.
The result: dramatically smaller graphs (fewer nodes, lower state size), faster model build and upload, and much cleaner model code.
Please consider exposing a gather primitive similar to numpy.take so we don’t have to build the workaround ourselves.
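For reference, this is the numpy.take behaviour I have in mind, shown on plain NumPy arrays with concrete integer indices. It only illustrates the scalar and batched semantics being requested; the array contents and variable names are made up for the example, and it says nothing about how the engine would implement the node.

```python
import numpy as np

arr = np.array([10, 20, 30, 40])

# Scalar lookup: one index in, one value out.
scalar_value = np.take(arr, 2)

# Batched lookup: an array of indices in, an array of values out.
batched_values = np.take(arr, np.array([3, 0, 0, 1]))

print(scalar_value)    # 30
print(batched_values)  # [40 10 10 20]
```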
Current workaround
For each lookup I do something like:
# array: length H
# choice: symbolic index
mask = where(range_vec == choice, ones_vec, zeros_vec)
value = (mask * array).sum()
If the array has H elements, the loop expands to:
H equality checks (range_vec == choice)
H Where nodes to create the mask
Up to H multiplications for mask * array
H-1 additions to accumulate the result.
When this index-selection happens inside a bigger loop—for example, per job, per candidate window, per resource—the node count explodes.
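To make the cost concrete, here is a rough sketch in plain NumPy. It emulates the one-hot workaround with concrete (not symbolic) indices, checks that it returns the same values as numpy.take, and tallies the upper-bound node count from the list above. The helper names one_hot_lookup and workaround_nodes are hypothetical, and the 50 jobs × 20 windows figure is an assumed example, not a measurement from a real model.

```python
import numpy as np

def one_hot_lookup(array, choice):
    """Select array[choice] via a one-hot mask, mirroring the workaround."""
    H = len(array)
    range_vec = np.arange(H)
    ones_vec = np.ones(H, dtype=array.dtype)
    zeros_vec = np.zeros(H, dtype=array.dtype)
    mask = np.where(range_vec == choice, ones_vec, zeros_vec)
    return (mask * array).sum()

def workaround_nodes(H, lookups=1):
    """Upper bound on nodes: H equality checks + H where nodes
    + H multiplications + (H - 1) additions, per lookup."""
    return lookups * (4 * H - 1)

array = np.array([5, 7, 11, 13])

# The mask-based lookup agrees with np.take for every valid index.
matches_take = all(
    int(one_hot_lookup(array, i)) == int(np.take(array, i))
    for i in range(len(array))
)

nodes_single = workaround_nodes(H=100)                    # one lookup
nodes_nested = workaround_nodes(H=100, lookups=50 * 20)   # assumed 50 jobs x 20 windows

print(matches_take, nodes_single, nodes_nested)  # True 399 399000
```

With a built-in take, each of those lookups would ideally be a single node, so the assumed 1,000 lookups would cost about 1,000 nodes instead of several hundred thousand.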