-
Notifications
You must be signed in to change notification settings - Fork 119
[Feature] Add Beam Search as decoding strategy #110
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great job on the PR! I left some mostly minor comments
rl4co/models/nn/dec_strategies.py
Outdated
|
||
if decoding_strategy not in strategy_registry: | ||
log.warning( | ||
f"Unknown environment name '{decoding_strategy}'. Available dynamic embeddings: {strategy_registry.keys()}. Defaulting to Sampling." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor comment: change message (Unknown decoding strategy - Available strategies...)
rl4co/models/nn/dec_strategies.py
Outdated
return strategy_registry.get(decoding_strategy, Sampling)(**config) | ||
|
||
|
||
class DecodingStrategy(nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor comment: is there a specific reason for the class being an nn.Module
, since there is no forward
and parameters are not saved?
Also, calling super().__init__()
might have some (possibly minor) slowdowns if functionality is not used
outputs.append(log_p) | ||
actions.append(action) | ||
# setup decoding strategy | ||
self.decode_strategy: DecodingStrategy = get_decoding_strategy( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if instantiating the strategy each time has some impact on speed? I guess not much. In case there is indeed some difference, it might be worth it to "cache" the strategy - i.e. since we are saving it in self
, then we could do something like:
if self.decode_strategy.name != decode_type:
self.decode_strategy: DecodingStrategy = get_decoding_strategy(
decode_type, **strategy_kwargs
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the quick review @fedebotu! Great catch, making the DecodingStrategy a subclass of nn.Module is indeed not necessary. After removing that, I did some profiling to check the impact of instantiating the strategy each time. Results indicate a super minor impact on speed, please see the image below. For the test, I implemented an optional cache, similar to what you proposed, and checked the speed with and without caching the strategy.
From a readability point-of-view, I would argue that the cache is not necessary. What do you suggest? :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now that the class instantiation is simple, there is basically no overhead and much cleaner code. Let's keep it as is!
If you have some time, how about making a simple notebook (like this) with training and evaluation on different decoding strategies? I think it would make a great tutorial 😄 |
Thanks a lot @LTluttmann for your contribution to RL4CO! |
Description
added base DecodingStrategy class which defines the following functions
implemented the greedy and sampling decoding strategies using the DecodingStrategy baseclass. They also support multistart by simply specifying "multistart_greedy" as decode_type for example. From a user perspective, nothing changes here.
added beam search using the DecodingStrategy baseclass overwriting the following functions:
added new test function to check that the new approach doesn't break anything
Motivation and Context
close #109
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an
x
in all the boxes that apply.If you are unsure about any of these, don't hesitate to ask. We are here to help!