[Bug] _MultitaskGaussianLikelihoodBase does not pass additional arguments in marginal #2630

Open · adrianLepp opened this issue Feb 1, 2025 · 2 comments

🐛 Bug

In _MultitaskGaussianLikelihoodBase, the call to self._shaped_noise_covar inside marginal does not pass the additional arguments through, as is done in _GaussianLikelihoodBase. I wanted to implement the equivalent of FixedNoiseGaussianLikelihood for a multitask GP. There it should be possible to pass a noise tensor with the shape of the outputs to the likelihood, as in this single-task example (a paraphrase of the two diverging code paths follows the example):

train_x = torch.randn(55, 2)
noises = torch.ones(55) * 0.01
likelihood = FixedNoiseGaussianLikelihood(noise=noises, learn_additional_noise=True)
pred_y = likelihood(gp_model(train_x))

test_x = torch.randn(21, 2)
test_noises = torch.ones(21) * 0.02
pred_y = likelihood(gp_model(test_x), noise=test_noises)
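
For context, the two code paths diverge roughly like this (paraphrased from the gpytorch source, not verbatim; exact signatures may differ between versions):

# Paraphrased sketch, not the verbatim gpytorch source.
class _GaussianLikelihoodBase(Likelihood):
    def marginal(self, function_dist, *params, **kwargs):
        mean, covar = function_dist.mean, function_dist.lazy_covariance_matrix
        # The single-task base class forwards the extra arguments:
        noise_covar = self._shaped_noise_covar(mean.shape, *params, **kwargs)
        return function_dist.__class__(mean, covar + noise_covar)

class _MultitaskGaussianLikelihoodBase(_GaussianLikelihoodBase):
    def marginal(self, function_dist, *params, **kwargs):
        mean, covar = function_dist.mean, function_dist.lazy_covariance_matrix
        # ...while the multitask base class drops *params/**kwargs here:
        covar_kron_lt = self._shaped_noise_covar(mean.shape, add_noise=self.has_global_noise)
        return function_dist.__class__(mean, covar + covar_kron_lt)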

To reproduce

This is my class (WIP), based on the solutions considered in #901:

import torch
from gpytorch.likelihoods.multitask_gaussian_likelihood import _MultitaskGaussianLikelihoodBase
from gpytorch.likelihoods.noise_models import FixedGaussianNoise
from linear_operator.operators import (
    ConstantDiagLinearOperator,
    DiagLinearOperator,
    KroneckerProductLinearOperator,
)

class FixedTaskNoiseMultitaskLikelihood(_MultitaskGaussianLikelihoodBase):
    def __init__(
        self,
        noise: torch.Tensor,
        has_task_noise: bool = False,
        task_noise: torch.Tensor = None,
        task_noise_factor: torch.Tensor = None,
        *args,
        **kwargs
    ) -> None:
        noise_covar = FixedGaussianNoise(noise=noise)
        super().__init__(*args, noise_covar=noise_covar, **kwargs)
        self.has_global_noise = False
        self.has_task_noise = has_task_noise

        if self.has_task_noise:
            if task_noise is not None:
                # A fixed noise diagonal covering all outputs.
                self.task_noise = task_noise
                self.task_noise_factor = None
            elif task_noise_factor is not None:
                # Per-task factors, combined with the data noise via a Kronecker product.
                self.task_noise_factor = task_noise_factor
                self.task_noise = None
            else:
                raise ValueError("Must supply task noise or task noise factor")

    def _shaped_noise_covar(self, base_shape: torch.Size, *params, add_noise=True, **kwargs):
        if self.has_task_noise and self.task_noise is not None:
            # Prefer a noise tensor supplied at call time, e.g. likelihood(..., noise=test_noises).
            if kwargs.get("noise") is not None:
                return DiagLinearOperator(kwargs["noise"])
            return DiagLinearOperator(self.task_noise)

        # Data noise with one entry per data point; base_shape is (..., n, num_tasks).
        data_noise = self.noise_covar(*params, shape=torch.Size((base_shape[-2],)), **kwargs)

        if not self.has_task_noise:
            # No task-specific noise: identity over the task dimension.
            eye = torch.ones(1, device=data_noise.device, dtype=data_noise.dtype)
            task_noise = ConstantDiagLinearOperator(
                eye, diag_shape=torch.Size((self.num_tasks,))
            )
        else:  # task_noise_factor was supplied
            task_noise_factor = self.task_noise_factor.to(device=data_noise.device, dtype=data_noise.dtype)
            task_noise = DiagLinearOperator(task_noise_factor)

        return KroneckerProductLinearOperator(data_noise, task_noise)
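
As a quick sanity check of the Kronecker structure above (my own aside, not from the original report): for diagonal factors, the joint noise diagonal is just torch.kron of the two diagonals.

import torch
from linear_operator.operators import DiagLinearOperator, KroneckerProductLinearOperator

a = torch.tensor([0.01, 0.02, 0.03])  # data noise, one entry per point (n = 3)
b = torch.tensor([1.0, 2.0])          # per-task factors (num_tasks = 2)
K = KroneckerProductLinearOperator(DiagLinearOperator(a), DiagLinearOperator(b))
# diag(a) ⊗ diag(b) == diag(kron(a, b)), so the (6, 6) operator stays diagonal.
assert torch.allclose(K.to_dense().diagonal(), torch.kron(a, b))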

Calling

task_noise = torch.ones_like(train_y).flatten()
likelihood = FixedTaskNoiseMultitaskLikelihood(
    noise=noise, num_tasks=num_tasks, rank=num_tasks, has_task_noise=True, task_noise=task_noise
)

test_noises = torch.ones(torch.Size((model.num_tasks, test_x.shape[0]))).flatten()
likelihood(gp_model(test_x), noise=test_noises)

should pass noise through to _shaped_noise_covar. Currently it never arrives, because marginal in _MultitaskGaussianLikelihoodBase calls _shaped_noise_covar without forwarding the additional arguments.
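
Until marginal forwards these arguments upstream, a possible workaround is to override it in the subclass. A minimal, untested sketch, assuming the paraphrased marginal structure above:

class FixedTaskNoiseMultitaskLikelihood(_MultitaskGaussianLikelihoodBase):
    # ... __init__ and _shaped_noise_covar as above ...

    def marginal(self, function_dist, *params, **kwargs):
        mean, covar = function_dist.mean, function_dist.lazy_covariance_matrix
        # Forward *params/**kwargs so a call-time noise tensor reaches _shaped_noise_covar.
        noise_covar = self._shaped_noise_covar(
            mean.shape, *params, add_noise=self.has_global_noise, **kwargs
        )
        return function_dist.__class__(mean, covar + noise_covar)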

adrianLepp added the bug label Feb 1, 2025
gpleiss (Member) commented Feb 8, 2025

Gotcha. Any chance you can make a PR to fix this bug?

adrianLepp (Author) commented
@gpleiss Technically, yes. I just need to find the time to read the contribution guidelines and set up the development environment. Since it is a small change, I wonder whether it would be more efficient for someone already involved to implement it.
