"""Optimisation by exhaustive search, aka grid search."""
from typing import List, Optional

from hypertunity.domain import Domain, DomainNotIterableError, Sample
from hypertunity.optimisation.base import ExhaustedSearchSpaceError, Optimiser
# Names exported by `from <this module> import *`.
__all__ = [
"GridSearch"
]
class GridSearch(Optimiser):
    """Grid search pseudo-optimiser.

    Iterates over the combined discrete and categorical part of a domain.
    If the domain also has a continuous part (and sampling it was enabled),
    every visited point is extended with a sample drawn from that part.
    """

    # Message used whenever the walk over the discrete part has finished.
    _EXHAUSTED_MSG = (
        "The domain has been exhausted. Reset the optimiser to start again."
    )

    def __init__(self,
                 domain: Domain,
                 sample_continuous: bool = False,
                 seed: Optional[int] = None):
        """Initialise the :class:`GridSearch` optimiser from a discrete domain.

        If the domain contains continuous subspaces, then they could be sampled
        if `sample_continuous` is enabled.

        Args:
            domain: :class:`Domain`. The domain to iterate over.
            sample_continuous: (optional) :obj:`bool`. Whether to sample the
                continuous subspaces of the domain.
            seed: (optional) :obj:`int`. Seed for the sampling of the continuous
                subspace if necessary. When given, the continuous subspace is
                rebuilt as a new :class:`Domain` created with this seed.

        Raises:
            :class:`DomainNotIterableError`: if `domain.is_continuous` is set
                and `sample_continuous` is not enabled.
        """
        if domain.is_continuous and not sample_continuous:
            raise DomainNotIterableError(
                "Cannot perform grid search on (partially) continuous domain. "
                "To enable grid search in this case, set the argument "
                "'sample_continuous' to True."
            )
        super().__init__(domain)
        (
            discrete_domain,
            categorical_domain,
            continuous_domain
        ) = domain.split_by_type()
        # unify the discrete and the categorical into one,
        # as they can be iterated:
        self.discrete_domain = discrete_domain + categorical_domain
        if seed is not None:
            self.continuous_domain = Domain(
                continuous_domain.as_dict(), seed=seed
            )
        else:
            self.continuous_domain = continuous_domain
        self._restart_walk()

    def _restart_walk(self):
        """Point the iterator at the start of the discrete domain again."""
        self._discrete_domain_iter = iter(self.discrete_domain)
        # A discrete part of length 0 is treated as exhausted from the start.
        self._is_exhausted = len(self.discrete_domain) == 0

    def _exhausted_error(self) -> ExhaustedSearchSpaceError:
        """Build a fresh exhaustion error.

        A new instance is created for every raise on purpose: raising one
        stored exception object repeatedly keeps extending its traceback and
        keeps the frames of earlier raises alive.
        """
        return ExhaustedSearchSpaceError(self._EXHAUSTED_MSG)

    def run_step(self, batch_size: int = 1, **kwargs) -> List[Sample]:
        """Get the next `batch_size` samples from the Cartesian-product walk
        over the domain.

        Args:
            batch_size: (optional) :obj:`int`. The number of samples to suggest
                at once.
            **kwargs: accepted but not used by this method.

        Returns:
            A list of :class:`Sample` instances from the domain.

        Raises:
            :class:`ExhaustedSearchSpaceError`: if the (discrete part of the)
                domain is fully exhausted and no samples can be generated.

        Notes:
            This method does not guarantee that the returned list of
            :class:`Samples` will be of length `batch_size`. This is due to the
            size of the domain and the fact that samples will not be repeated.

            A `batch_size` smaller than 1 produces no samples, so the call
            raises :class:`ExhaustedSearchSpaceError` as well, without marking
            the domain as exhausted.
        """
        if self._is_exhausted:
            raise self._exhausted_error()
        samples = []
        for _ in range(batch_size):
            try:
                discrete = next(self._discrete_domain_iter)
            except StopIteration:
                # Remember the exhaustion so that the next call fails fast,
                # but still hand back whatever was collected in this batch.
                self._is_exhausted = True
                break
            if self.continuous_domain:
                # Extend the grid point with a draw from the continuous part.
                samples.append(discrete + self.continuous_domain.sample())
            else:
                samples.append(discrete)
        if not samples:
            raise self._exhausted_error()
        return samples

    def reset(self):
        """Reset the optimiser to the beginning of the Cartesian-product walk."""
        super().reset()
        self._restart_walk()