
Conversation

LTluttmann
Contributor

Description

  • added a base DecodingStrategy class, which defines the following functions:

    1. pre_decoder_hook: called before the while loop in the autoregressive decoder
    2. step: called in every iteration of the while loop; samples an action given the log probabilities. The actual selection logic is implemented by subclasses via a _step function
    3. post_decoder_hook: called right after the while loop, mainly used to concatenate outputs
  • implemented the greedy and sampling decoding strategies using the DecodingStrategy base class. They also support multistart by specifying, for example, "multistart_greedy" as the decode_type. From a user perspective, nothing changes here.

  • added beam search using the DecodingStrategy base class, overriding the following functions:

    1. pre_decoder_hook: as with the multistart options, the beam width is determined (if not specified) and the actions of the first iteration are simply the first beam_width nodes (this could be improved in the future)
    2. _step: performs the actual beam search
    3. post_decoder_hook: performs backtracking and brings the solution into the right format
  • added a new test function to check that the new approach doesn't break anything
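The hook-based pattern described above can be sketched as follows. This is a hedged illustration, not the actual RL4CO implementation: plain Python lists stand in for torch tensors, and the decoder state is a simple dict.

```python
# Sketch of the DecodingStrategy hook pattern (illustrative, not RL4CO's code).
import math
import random


class DecodingStrategy:
    name = "base"

    def pre_decoder_hook(self, state):
        # Called once before the autoregressive while loop.
        state.setdefault("actions", [])
        state.setdefault("log_probs", [])
        return state

    def step(self, log_probs, state):
        # Shared bookkeeping; action selection is delegated to _step.
        action = self._step(log_probs, state)
        state["actions"].append(action)
        state["log_probs"].append(log_probs[action])
        return action, state

    def _step(self, log_probs, state):
        raise NotImplementedError("subclasses implement the selection logic")

    def post_decoder_hook(self, state):
        # Called once after the loop, e.g. to collect the outputs.
        return state["actions"], state["log_probs"]


class Greedy(DecodingStrategy):
    name = "greedy"

    def _step(self, log_probs, state):
        # Pick the action with the highest log-probability.
        return max(range(len(log_probs)), key=lambda i: log_probs[i])


class Sampling(DecodingStrategy):
    name = "sampling"

    def _step(self, log_probs, state):
        # Sample an action proportionally to exp(log_prob).
        weights = [math.exp(lp) for lp in log_probs]
        return random.choices(range(len(weights)), weights=weights, k=1)[0]
```

A beam-search subclass would follow the same shape, with _step expanding and pruning beams and post_decoder_hook backtracking through the stored parent indices.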

Motivation and Context

close #109

  • I have raised an issue to propose this change

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

Member

@fedebotu fedebotu left a comment


Great job on the PR! I left some mostly minor comments


if decoding_strategy not in strategy_registry:
log.warning(
f"Unknown environment name '{decoding_strategy}'. Available dynamic embeddings: {strategy_registry.keys()}. Defaulting to Sampling."
Member


Minor comment: change message (Unknown decoding strategy - Available strategies...)

return strategy_registry.get(decoding_strategy, Sampling)(**config)


class DecodingStrategy(nn.Module):
Member


Minor comment: is there a specific reason for the class being an nn.Module, since there is no forward and parameters are not saved?
Also, calling super().__init__() might have some (possibly minor) slowdowns if functionality is not used

outputs.append(log_p)
actions.append(action)
# setup decoding strategy
self.decode_strategy: DecodingStrategy = get_decoding_strategy(
Member


I wonder if instantiating the strategy each time has some impact on speed? I guess not much. In case there is indeed some difference, it might be worth it to "cache" the strategy - i.e. since we are saving it in self, then we could do something like:

if self.decode_strategy.name != decode_type:
    self.decode_strategy: DecodingStrategy = get_decoding_strategy(
        decode_type, **strategy_kwargs
    )

Contributor Author


Thanks for the quick review @fedebotu! Great catch, making the DecodingStrategy a subclass of nn.Module is indeed not necessary. After removing that, I did some profiling to check the impact of instantiating the strategy each time. Results indicate a super minor impact on speed, please see the image below. For the test, I implemented an optional cache, similar to what you proposed, and checked the speed with and without caching the strategy.
From a readability point-of-view, I would argue that the cache is not necessary. What do you suggest? :)

[Screenshot, 2024-01-29: profiling results with and without caching the decoding strategy]

Member


Now that the class instantiation is simple, there is basically no overhead and much cleaner code. Let's keep it as is!

@fedebotu
Member

If you have some time, how about making a simple notebook (like this) with training and evaluation on different decoding strategies? I think it would make a great tutorial 😄

@fedebotu fedebotu added the feature New Feature label Jan 24, 2024
@fedebotu fedebotu merged commit eb22897 into ai4co:main Jan 30, 2024
@fedebotu
Member

Thanks a lot @LTluttmann for your contribution to RL4CO!
If you are interested in contributing with new features for the library, feel free to contact us (either here or through the AI4CO community Slack), we have a looong todo list ;)

Development

Successfully merging this pull request may close these issues.

[Feature Request] Add Beam Search as decoding strategy