Shortcuts

Source code for rl4co.models.zoo.common.search.base

from typing import Any, Optional, Union

from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch.utils.data import Dataset

from rl4co.models.rl.common.base import RL4COLitModule


class SearchBase(RL4COLitModule):
    """Base class for search algorithms.

    Search algorithms are used online to find better solutions for a given
    dataset, i.e. given a policy, improve (a part of) its parameters such that
    the policy performs better on the given dataset.

    Note:
        By default, we use manual optimization to handle the search
        (``self.automatic_optimization`` is set to ``False`` in ``__init__``).

    Args:
        env: RL4CO environment
        policy: policy network
        dataset: dataset to use for training. If a ``str`` is given, it is
            treated as a filename and loaded in :meth:`setup` via
            ``env.dataset(filename=...)``.
        batch_size: batch size; :meth:`setup` uses it for the train, val and
            test splits alike.
        max_iters: maximum number of search iterations. Only stored in
            ``self.hparams`` here; not read by this base class (presumably
            consumed by subclasses -- confirm).
        max_runtime: maximum runtime budget. Only stored in ``self.hparams``
            here (presumably in seconds given the 86_400 default -- confirm).
        save_path: optional path. Only stored in ``self.hparams`` here; its
            use is up to subclasses.
        **kwargs: additional arguments forwarded to ``RL4COLitModule``
    """

    def __init__(
        self,
        env,
        policy,
        dataset: Union[Dataset, str],
        batch_size: int = 1,
        max_iters: int = 100,
        max_runtime: Optional[int] = 86_400,
        save_path: Optional[str] = None,
        **kwargs,
    ):
        # Record the constructor arguments (incl. batch_size, max_iters,
        # max_runtime, save_path) in ``self.hparams`` without logging them;
        # ``setup`` reads ``self.hparams.batch_size``.
        self.save_hyperparameters(logger=False)
        super().__init__(env, policy, **kwargs)
        self.dataset = dataset
        self.automatic_optimization = False  # we optimize manually

    def setup(self, stage="fit"):
        """Set up the dataset and attributes.

        The ``RL4COLitModule`` base class automatically loads the data; this
        override does not call ``super().setup`` and instead uses the single
        dataset passed at construction for every split.

        Args:
            stage: Lightning stage name; not used here.
        """
        if isinstance(self.dataset, str):
            # A string dataset is a filename: load it through the environment
            self.dataset = self.env.dataset(filename=self.dataset)

        # Set all datasets and batch size as the same
        for split in ["train", "val", "test"]:
            setattr(self, f"{split}_dataset", self.dataset)
            setattr(self, f"{split}_batch_size", self.hparams.batch_size)

        # Setup loggers
        self.setup_loggers()

    def on_train_batch_start(self, batch: Any, batch_idx: int):
        """Called before training (i.e. search) for a new batch begins.

        This can be used to perform changes to the model or optimizer at the
        start of each batch. No-op in this base class.
        """
        pass  # Implement in subclass

    def training_step(self, batch, batch_idx):
        """Main search loop.

        We use the training step to effectively adapt to a `batch` of
        instances. Must be implemented by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Implement in subclass")

    def on_train_batch_end(
        self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int
    ) -> None:
        """Called when the train batch ends.

        This can be used for instance for logging or clearing cache.
        No-op in this base class.
        """
        pass  # Implement in subclass

    def on_train_epoch_end(self) -> None:
        """Called when the training epoch ends. No-op in this base class."""
        pass  # Implement in subclass

    def validation_step(self, batch: Any, batch_idx: int):
        """Not used during search; intentionally a no-op."""
        pass

    def test_step(self, batch: Any, batch_idx: int):
        """Not used during search; intentionally a no-op."""
        pass

© Copyright Federico Berto, Chuanbo Hua, Junyoung Park. Revision 14d072ed.

Built with Sphinx using a theme provided by Read the Docs.