Skip to content



Bases: VIPRSGrid

The VIPRSGridSearch class is an extension of the VIPRSGrid class that implements grid search for the VIPRS models. The grid search procedure fits multiple models to the data, each with different hyperparameters, and selects the best model based on user-defined criteria.

The criteria supported are:

  • ELBO: The model with the highest ELBO is selected.
  • validation: The model with the highest R^2 on the validation set is selected.
  • pseudo_validation: The model with the highest pseudo-validation R^2 is selected.

Note that the validation and pseudo_validation criteria require the user to provide validation data in the form of paired genotype/phenotype data or external GWAS summary statistics.

Source code in viprs/model/gridsearch/
class VIPRSGridSearch(VIPRSGrid):
    """
    The `VIPRSGridSearch` class is an extension of the `VIPRSGrid` class that
    implements grid search for the `VIPRS` models. The grid search procedure
    fits multiple models to the data, each with different hyperparameters,
    and selects the best model based on user-defined criteria.

    The criteria supported are:

    * `ELBO`: The model with the highest ELBO is selected.
    * `validation`: The model with the highest R^2 on the validation set is selected.
    * `pseudo_validation`: The model with the highest pseudo-validation R^2 is selected.

    Only models that terminated successfully are eligible for selection.

    Note that the `validation` and `pseudo_validation` criteria require the user to provide
    validation data in the form of paired genotype/phenotype data or external GWAS summary
    statistics.
    """

    def __init__(self, gdl, grid, **kwargs):
        """
        Initialize the `VIPRSGridSearch` model.

        :param gdl: An instance of `GWADataLoader`
        :param grid: An instance of `HyperparameterGrid`
        :param kwargs: Additional keyword arguments to pass to the parent `VIPRSGrid` class.
        """

        super().__init__(gdl, grid=grid, **kwargs)

    def select_best_model(self, validation_gdl=None, criterion='ELBO'):
        """
        From the grid of models that were fit to the data, select the best
        model according to the specified `criterion`. If the criterion is the ELBO,
        the model with the highest ELBO will be selected. If the criterion is
        validation or pseudo-validation, the model with the highest R^2 on the
        validation set will be selected.

        Only models flagged in `self.valid_terminated_models` are eligible. If
        exactly one model converged, it is selected directly and the criterion
        is not evaluated. Non-finite scores (NaN / +-inf) are treated as the
        worst possible score.

        After selection, the per-model variational parameters and
        hyperparameters are reduced in place to the selected model and
        `n_models` is set to 1.

        :param validation_gdl: An instance of `GWADataLoader` containing data from the validation set.
        Must be provided if criterion is `validation` or `pseudo_validation`.
        :param criterion: The criterion for selecting the best model.
        Options are: (`ELBO`, `validation`, `pseudo_validation`)

        :return: `self`, reduced to the selected model.
        :raises ValueError: If no model converged successfully.
        """

        assert criterion in ('ELBO', 'validation', 'pseudo_validation')

        # Boolean mask (one entry per model) of models that terminated successfully:
        models_converged = np.asarray(self.valid_terminated_models, dtype=bool)
        converged_idx = np.flatnonzero(models_converged)

        if converged_idx.size < 1:
            raise ValueError("No models converged successfully. "
                             "Cannot select best model.")
        elif converged_idx.size == 1:
            # A single candidate: no need to evaluate the criterion.
            best_model_idx = converged_idx[0]
        else:
            if criterion == 'ELBO':
                # Copy, so that selection never overwrites the recorded ELBO history:
                scores = np.array(self.history['ELBO'][-1], dtype=float)
            elif criterion == 'validation':

                assert validation_gdl is not None
                assert validation_gdl.sample_table is not None
                assert validation_gdl.sample_table.phenotype is not None

                from viprs.eval.continuous_metrics import r2

                prs = self.predict(test_gdl=validation_gdl)
                scores = np.array([r2(prs[:, i], validation_gdl.sample_table.phenotype)
                                   for i in range(self.n_models)])
                scores[~models_converged] = -np.inf
                self.validation_result['Validation_R2'] = scores
            else:  # criterion == 'pseudo_validation'

                scores = self.pseudo_validate(validation_gdl, metric='r2')
                scores[~models_converged] = -np.inf
                self.validation_result['Pseudo_Validation_R2'] = scores

            # Treat non-finite scores as the worst outcome and restrict the argmax to
            # converged models, so a non-converged model can never be selected (mapping
            # -inf to 0 would let masked models win when all converged scores are negative).
            finite_scores = np.where(np.isfinite(scores), scores, -np.inf)
            best_model_idx = converged_idx[np.argmax(finite_scores[converged_idx])]

        if int(self.verbose) > 1:
            print(f"> Based on the {criterion} criterion, selected model: {best_model_idx}")
            print("> Model details:\n")
            # NOTE(review): the detail printout was lost in the source listing; this assumes
            # `validation_result` maps names to per-model sequences — confirm.
            if self.validation_result is not None:
                print({k: v[best_model_idx] for k, v in self.validation_result.items()})

        # -----------------------------------------------------------------------
        # Update the variational parameters and their dependencies to only select the best model.
        # Each is a dict of 2D arrays whose columns index the models:
        for param in (self.pip, self.post_mean_beta, self.post_var_beta,
                      self.var_gamma, self.var_mu, self.var_tau,
                      self.eta, self.zeta, self.q, self._log_var_tau):

            for c in param:
                param[c] = param[c][:, best_model_idx]

        # Update sigma_epsilon:
        self.sigma_epsilon = self.sigma_epsilon[best_model_idx]

        # Update sigma_g:
        self._sigma_g = self._sigma_g[best_model_idx]

        # Update tau_beta: either a dict of 2D arrays (columns index models)
        # or a container indexed directly by model.
        if isinstance(self.tau_beta, dict):
            for c in self.tau_beta:
                self.tau_beta[c] = self.tau_beta[c][:, best_model_idx]
        else:
            self.tau_beta = self.tau_beta[best_model_idx]

        # Update pi (same two layouts as tau_beta):
        if isinstance(self.pi, dict):
            for c in self.pi:
                self.pi[c] = self.pi[c][:, best_model_idx]
        else:
            self.pi = self.pi[best_model_idx]

        # Set the number of models to 1:
        self.n_models = 1

        # -----------------------------------------------------------------------

        return self

__init__(gdl, grid, **kwargs)

Initialize the VIPRSGridSearch model.


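| Name | Type | Description | Default |
|------|------|-------------|---------|
| `gdl` | `GWADataLoader` | An instance of `GWADataLoader`. | *required* |
| `grid` | `HyperparameterGrid` | An instance of `HyperparameterGrid`. | *required* |
| `**kwargs` | | Additional keyword arguments to pass to the parent `VIPRSGrid` class. | `{}` |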

Source code in viprs/model/gridsearch/
def __init__(self, gdl, grid, **kwargs):
    """
    Initialize the `VIPRSGridSearch` model.

    Delegates entirely to the parent `VIPRSGrid` constructor.

    :param gdl: An instance of `GWADataLoader`
    :param grid: An instance of `HyperparameterGrid`
    :param kwargs: Additional keyword arguments to pass to the parent `VIPRSGrid` class.
    """

    super().__init__(gdl, grid=grid, **kwargs)

select_best_model(validation_gdl=None, criterion='ELBO')

From the grid of models that were fit to the data, select the best model according to the specified criterion. Only models that converged successfully are eligible; if exactly one model converged, it is selected directly. If the criterion is the ELBO, the model with the highest ELBO will be selected. If the criterion is validation or pseudo-validation, the model with the highest R^2 on the validation set will be selected.


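| Name | Type | Description | Default |
|------|------|-------------|---------|
| `validation_gdl` | `GWADataLoader` | An instance of `GWADataLoader` containing data from the validation set. Must be provided if `criterion` is `validation` or `pseudo_validation`. | `None` |
| `criterion` | `str` | The criterion for selecting the best model. Options are: (`ELBO`, `validation`, `pseudo_validation`). | `'ELBO'` |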

Source code in viprs/model/gridsearch/
def select_best_model(self, validation_gdl=None, criterion='ELBO'):
    """
    From the grid of models that were fit to the data, select the best
    model according to the specified `criterion`. If the criterion is the ELBO,
    the model with the highest ELBO will be selected. If the criterion is
    validation or pseudo-validation, the model with the highest R^2 on the
    validation set will be selected.

    Only models flagged in `self.valid_terminated_models` are eligible. If
    exactly one model converged, it is selected directly and the criterion
    is not evaluated. Non-finite scores (NaN / +-inf) are treated as the
    worst possible score.

    After selection, the per-model variational parameters and
    hyperparameters are reduced in place to the selected model and
    `n_models` is set to 1.

    :param validation_gdl: An instance of `GWADataLoader` containing data from the validation set.
    Must be provided if criterion is `validation` or `pseudo_validation`.
    :param criterion: The criterion for selecting the best model.
    Options are: (`ELBO`, `validation`, `pseudo_validation`)

    :return: `self`, reduced to the selected model.
    :raises ValueError: If no model converged successfully.
    """

    assert criterion in ('ELBO', 'validation', 'pseudo_validation')

    # Boolean mask (one entry per model) of models that terminated successfully:
    models_converged = np.asarray(self.valid_terminated_models, dtype=bool)
    converged_idx = np.flatnonzero(models_converged)

    if converged_idx.size < 1:
        raise ValueError("No models converged successfully. "
                         "Cannot select best model.")
    elif converged_idx.size == 1:
        # A single candidate: no need to evaluate the criterion.
        best_model_idx = converged_idx[0]
    else:
        if criterion == 'ELBO':
            # Copy, so that selection never overwrites the recorded ELBO history:
            scores = np.array(self.history['ELBO'][-1], dtype=float)
        elif criterion == 'validation':

            assert validation_gdl is not None
            assert validation_gdl.sample_table is not None
            assert validation_gdl.sample_table.phenotype is not None

            from viprs.eval.continuous_metrics import r2

            prs = self.predict(test_gdl=validation_gdl)
            scores = np.array([r2(prs[:, i], validation_gdl.sample_table.phenotype)
                               for i in range(self.n_models)])
            scores[~models_converged] = -np.inf
            self.validation_result['Validation_R2'] = scores
        else:  # criterion == 'pseudo_validation'

            scores = self.pseudo_validate(validation_gdl, metric='r2')
            scores[~models_converged] = -np.inf
            self.validation_result['Pseudo_Validation_R2'] = scores

        # Treat non-finite scores as the worst outcome and restrict the argmax to
        # converged models, so a non-converged model can never be selected (mapping
        # -inf to 0 would let masked models win when all converged scores are negative).
        finite_scores = np.where(np.isfinite(scores), scores, -np.inf)
        best_model_idx = converged_idx[np.argmax(finite_scores[converged_idx])]

    if int(self.verbose) > 1:
        print(f"> Based on the {criterion} criterion, selected model: {best_model_idx}")
        print("> Model details:\n")
        # NOTE(review): the detail printout was lost in the source listing; this assumes
        # `validation_result` maps names to per-model sequences — confirm.
        if self.validation_result is not None:
            print({k: v[best_model_idx] for k, v in self.validation_result.items()})

    # -----------------------------------------------------------------------
    # Update the variational parameters and their dependencies to only select the best model.
    # Each is a dict of 2D arrays whose columns index the models:
    for param in (self.pip, self.post_mean_beta, self.post_var_beta,
                  self.var_gamma, self.var_mu, self.var_tau,
                  self.eta, self.zeta, self.q, self._log_var_tau):

        for c in param:
            param[c] = param[c][:, best_model_idx]

    # Update sigma_epsilon:
    self.sigma_epsilon = self.sigma_epsilon[best_model_idx]

    # Update sigma_g:
    self._sigma_g = self._sigma_g[best_model_idx]

    # Update tau_beta: either a dict of 2D arrays (columns index models)
    # or a container indexed directly by model.
    if isinstance(self.tau_beta, dict):
        for c in self.tau_beta:
            self.tau_beta[c] = self.tau_beta[c][:, best_model_idx]
    else:
        self.tau_beta = self.tau_beta[best_model_idx]

    # Update pi (same two layouts as tau_beta):
    if isinstance(self.pi, dict):
        for c in self.pi:
            self.pi[c] = self.pi[c][:, best_model_idx]
    else:
        self.pi = self.pi[best_model_idx]

    # Set the number of models to 1:
    self.n_models = 1

    # -----------------------------------------------------------------------

    return self