Source code for rl4co.models.zoo.common.search.base
from typing import Any, Optional, Union
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch.utils.data import Dataset
from rl4co.models.rl.common.base import RL4COLitModule
[docs]class SearchBase(RL4COLitModule):
"""Base class for search algorithms. Search algorithms
are used onlin to find better solutions for a given dataset, i.e.
given a policy, improve (a part of) its parameters such that
the policy performs better on the given dataset.
Note:
By default, we use manual optimization to handle the search.
Args:
env: RL4CO environment
policy: policy network
dataset: dataset to use for training
batch_size: batch size
**kwargs: additional arguments
"""
def __init__(
self,
env,
policy,
dataset: Union[Dataset, str],
batch_size: int = 1,
max_iters: int = 100,
max_runtime: Optional[int] = 86_400,
save_path: Optional[str] = None,
**kwargs,
):
self.save_hyperparameters(logger=False)
super().__init__(env, policy, **kwargs)
self.dataset = dataset
self.automatic_optimization = False # we optimize manually
[docs] def setup(self, stage="fit"):
"""Setup the dataset and attributes.
The RL4COLitModulebase class automatically loads the data.
"""
if isinstance(self.dataset, str):
# load from file
self.dataset = self.env.dataset(filename=self.dataset)
# Set all datasets and batch size as the same
for split in ["train", "val", "test"]:
setattr(self, f"{split}_dataset", self.dataset)
setattr(self, f"{split}_batch_size", self.hparams.batch_size)
# Setup loggers
self.setup_loggers()
[docs] def on_train_batch_start(self, batch: Any, batch_idx: int):
"""Called before training (i.e. search) for a new batch begins.
This can be used to perform changes to the model or optimizer at the start of each batch.
"""
pass # Implement in subclass
[docs] def training_step(self, batch, batch_idx):
"""Main search loop. We use the training step to effectively adapt to a `batch` of instances."""
raise NotImplementedError("Implement in subclass")
[docs] def on_train_batch_end(
self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int
) -> None:
"""Called when the train batch ends. This can be used for
instance for logging or clearing cache.
"""
pass # Implement in subclass
[docs] def on_train_epoch_end(self) -> None:
"""Called when the train ends."""
pass # Implement in subclass
[docs] def validation_step(self, batch: Any, batch_idx: int):
"""Not used during search"""
pass
[docs] def test_step(self, batch: Any, batch_idx: int):
"""Not used during search"""
pass