Skip to content

Commit

Permalink
[MRG] Regression label for 2d classification data generation (#69)
Browse files Browse the repository at this point in the history
* added regression labels

* added a new test for regression label

* added an exemple of uses of teh regression label and bug fixes

the main bug that has been changes was due to the y.astype(int) in the return that rounded the float values when using regression labels

* changed the plot to have only one colorbar

* solved an issue with test_make_shifted_datasets_regression

the test as using a method that has been changed, thus raising errors

* removed two test that were not making sense

the two tests that got removed were checking that the y-values were between 0 and 1, which should not necessarely be the case in regression

* changed the size of the colorbar

I just changed the size of the colorbar so that we have better looking plots next to it

* use label instead of binary in _generate_data_2d_classif_subspace

now _generate_data_2d_classif_subspace use the label that has been given in parametter instead of "binary" everytimes, additionally the example for the regression label use the subspace shift

* made multiclass usable for subspace shift

the main issue was that the y vlues that were generated weren't of the correct size (note the same as the X values)

* corrected a typo

* added a new test

this test should cover the change over generate_data_2d_classif_subspace when using 'multiclass' or 'regression' label

* updated a test

with subset shift, the values are twice smaller for the default case

* updated the test

this was needed with teh previous changes

* Update test_samples_generator.py

* Update test_samples_generator.py

* Update test_samples_generator.py

* made the code follow linter's standards

* correction of some mistake

* changes for flake8

---------

Co-authored-by: Rouxben <[email protected]>
  • Loading branch information
BuenoRuben and BuenoRuben authored Feb 7, 2024
1 parent 97897dc commit 8e8fc7b
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 20 deletions.
87 changes: 87 additions & 0 deletions examples/datasets/plot_shifted_dataset_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""
Plot dataset source domain and shifted target domain
====================================================
This illustrates the :func:`~skada.datasets.make_shifted_dataset`
dataset generator. Each method consists of generating source data
and shifted target data. We illustrate here:
covariate shift, target shift, concept drift, and sample bias.
See detailed description of each shift in [1]_.
.. [1] Moreno-Torres, J. G., Raeder, T., Alaiz-Rodriguez,
R., Chawla, N. V., and Herrera, F. (2012).
A unifying view on dataset shift in classification.
Pattern recognition, 45(1):521-530.
"""
# %% Imports

import matplotlib.pyplot as plt

from skada.datasets import make_shifted_datasets
from skada import source_target_split


# %% Helper function

def plot_shifted_dataset(shift, random_state=42):
"""Plot source and shifted target data for a given type of shift.
The possible shifts are 'covariate_shift', 'target_shift',
'concept_drift' or 'subspace'.
We use here the same random seed for multiple calls to
ensure same distributions.
"""
X, y, sample_domain = make_shifted_datasets(
n_samples_source=20,
n_samples_target=20,
shift=shift,
noise=0.3,
label="regression",
random_state=random_state,
)
X_source, X_target, y_source, y_target = source_target_split(
X, y, sample_domain=sample_domain)

fig, (ax1, ax2) = plt.subplots(1, 2, sharex="row", sharey="row", figsize=(8, 4))
fig.suptitle(shift.replace("_", " ").title(), fontsize=14)
plt.subplots_adjust(bottom=0.15)
ax1.scatter(
X_source[:, 0],
X_source[:, 1],
c=y_source*10,
vmax=1,
alpha=0.5,
)
ax1.set_title("Source data")
ax1.set_xlabel("Feature 1")
ax1.set_ylabel("Feature 2")

s = ax2.scatter(
X_target[:, 0],
X_target[:, 1],
c=y_target*10,
vmax=1,
alpha=0.5,
)
ax2.set_title("Target data")
ax2.set_xlabel("Feature 1")
ax2.set_ylabel("Feature 2")

fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.80])
cb = fig.colorbar(s, cax=cbar_ax)
cb.set_label("y-value*10")

plt.show()


# %% Visualize shifted datasets

for shift in [
"covariate_shift",
"target_shift",
"concept_drift",
"subspace"
]:
plot_shifted_dataset(shift)
41 changes: 32 additions & 9 deletions skada/datasets/_samples_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from scipy import signal
from scipy.fftpack import rfft, irfft
from scipy.stats import multivariate_normal

from sklearn.datasets import make_blobs

Expand Down Expand Up @@ -37,6 +38,7 @@ def _generate_data_2d_classif(n_samples, rng, label='binary'):
label : tuple, default='binary'
If 'binary, return binary class
If 'multiclass', return multiclass
if 'regression', return regression's y-values
"""
n2 = n_samples
n1 = n2 * 4
Expand Down Expand Up @@ -75,14 +77,20 @@ def _generate_data_2d_classif(n_samples, rng, label='binary'):
# make labels
if label == 'binary':
y = np.concatenate((np.zeros(n1), np.ones(4 * n2)), 0)
y = y.astype(int)
elif label == 'multiclass':
y = np.zeros(n1)
for i in range(4):
y = np.concatenate((y, (i + 1) * np.ones(n2)), 0)
y = y.astype(int)
elif label == 'regression':
# create label y with gaussian distribution
normal_rv = multivariate_normal(mu1, Sigma1)
y = normal_rv.pdf(x)
else:
raise ValueError(f"Invalid label value: {label}. The label should either be "
"'binary' or 'multiclass'")
return x, y.astype(int)
"'binary', 'multiclass' or 'regression'")
return x, y


def _generate_data_2d_classif_subspace(n_samples, rng, label='binary'):
Expand All @@ -98,6 +106,7 @@ def _generate_data_2d_classif_subspace(n_samples, rng, label='binary'):
label : tuple, default='binary'
If 'binary, return binary class
If 'multiclass', return multiclass
if 'regression', return regression's y-values
"""
n2 = n_samples
n1 = n2 * 2
Expand All @@ -124,15 +133,29 @@ def _generate_data_2d_classif_subspace(n_samples, rng, label='binary'):
# make labels
if label == 'binary':
y = np.concatenate((np.zeros(n1), np.ones(2 * n2)), 0)
y = y.astype(int)
elif label == 'multiclass':
y = np.zeros(n1)
for i in range(4):
y = np.concatenate((y, (i + 1) * np.ones(n2)), 0)
k = 4
if n1 % k != 0:
raise ValueError(f"Invalid value: {n_samples}. This value "
"multiplied by 2 should be a multiple from {k}")
for i in range(k):
y = np.concatenate((y, (i + 1) * np.ones(n1//k)), 0)
y = y.astype(int)
elif label == 'regression':
# When using the label regressio we use different values for sigma and mu,
# to have more interesting plots
Sigma1 = np.array([[1, 0], [0, 1]])
mu1 = np.array([0, 0])

# create label y with gaussian distribution
normal_rv = multivariate_normal(mu1, Sigma1)
y = normal_rv.pdf(x)
else:
raise ValueError(f"Invalid label value: {label}. The label should either be "
"'binary' or 'multiclass'")

return x, y.astype(int)
"'binary', 'multiclass' or 'regression'")
return x, y


def _generate_data_from_moons(n_samples, index, rng):
Expand Down Expand Up @@ -456,10 +479,10 @@ def make_shifted_datasets(

elif shift == "subspace":
X_source, y_source = _generate_data_2d_classif_subspace(
n_samples_source, rng, "binary"
n_samples_source, rng, label
)
X_target, y_target = _generate_data_2d_classif_subspace(
n_samples_target, rng, "binary"
n_samples_target, rng, label
)
X_target *= -1

Expand Down
67 changes: 56 additions & 11 deletions skada/datasets/tests/test_samples_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def test_make_shifted_blobs():

@pytest.mark.parametrize(
"shift",
["covariate_shift", "target_shift", "concept_drift"],
["covariate_shift", "target_shift", "concept_drift", "subspace"],
)
def test_make_shifted_datasets(shift):
X, y, sample_domain = make_shifted_datasets(
Expand All @@ -115,18 +115,25 @@ def test_make_shifted_datasets(shift):
X_source, X_target, y_source, y_target = source_target_split(
X, y, sample_domain=sample_domain
)

assert X_source.shape == (10 * 8, 2), "X source shape mismatch"
assert y_source.shape == (10 * 8,), "y source shape mismatch"
if shift == "subspace":
assert X_source.shape == (10 * 8 // 2, 2), "X source shape mismatch"
assert y_source.shape == (10 * 8 // 2,), "y source shape mismatch"
else:
assert X_source.shape == (10 * 8, 2), "X source shape mismatch"
assert y_source.shape == (10 * 8,), "y source shape mismatch"
assert np.unique(y_source).shape == (2,), "Unexpected number of cluster"
assert X_target.shape == (10 * 8, 2), "X target shape mismatch"
assert y_target.shape == (10 * 8,), "y target shape mismatch"
if shift == "subspace":
assert X_target.shape == (10 * 8 // 2, 2), "X target shape mismatch"
assert y_target.shape == (10 * 8 // 2,), "y target shape mismatch"
else :
assert X_target.shape == (10 * 8, 2), "X target shape mismatch"
assert y_target.shape == (10 * 8,), "y target shape mismatch"
assert np.unique(y_target).shape == (2,), "Unexpected number of cluster"


@pytest.mark.parametrize(
"shift",
["covariate_shift", "target_shift", "concept_drift"],
["covariate_shift", "target_shift", "concept_drift", "subspace"],
)
def test_make_multi_source_shifted_datasets(shift):
# test for multi-source
Expand All @@ -142,14 +149,52 @@ def test_make_multi_source_shifted_datasets(shift):
X, y, sample_domain=sample_domain
)

assert X_source.shape == (10 * 8, 2), "X source shape mismatch"
assert y_source.shape == (10 * 8,), "y source shape mismatch"
if shift == "subspace":
assert X_source.shape == (10 * 8 // 2, 2), "X source shape mismatch"
assert y_source.shape == (10 * 8 // 2,), "y source shape mismatch"
else:
assert X_source.shape == (10 * 8, 2), "X source shape mismatch"
assert y_source.shape == (10 * 8,), "y source shape mismatch"
assert np.unique(y_source).shape == (5,), "Unexpected number of cluster"
assert X_target.shape == (10 * 8, 2), "X target shape mismatch"
assert y_target.shape == (10 * 8,), "y target shape mismatch"
if shift == "subspace":
assert X_target.shape == (10 * 8 // 2, 2), "X target shape mismatch"
assert y_target.shape == (10 * 8 // 2,), "y target shape mismatch"
else :
assert X_target.shape == (10 * 8, 2), "X target shape mismatch"
assert y_target.shape == (10 * 8,), "y target shape mismatch"
assert np.unique(y_target).shape[0] <= 5, "Unexpected number of cluster"


@pytest.mark.parametrize(
"shift",
["covariate_shift", "target_shift", "concept_drift", "subspace"],
)
def test_make_shifted_datasets_regression(shift):
X, y, sample_domain = make_shifted_datasets(
n_samples_source=10,
n_samples_target=10,
shift=shift,
noise=None,
label="regression",
)
X, y, sample_domain = check_X_y_domain(X, y, sample_domain)
X_source, X_target, y_source, y_target = source_target_split(
X, y, sample_domain=sample_domain)

if shift == "subspace":
assert X_source.shape == (10 * 8 // 2, 2), "X source shape mismatch"
assert y_source.shape == (10 * 8 // 2,), "y source shape mismatch"
else:
assert X_source.shape == (10 * 8, 2), "X source shape mismatch"
assert y_source.shape == (10 * 8,), "y source shape mismatch"
if shift == "subspace":
assert X_target.shape == (10 * 8 // 2, 2), "X target shape mismatch"
assert y_target.shape == (10 * 8 // 2,), "y target shape mismatch"
else:
assert X_target.shape == (10 * 8, 2), "X target shape mismatch"
assert y_target.shape == (10 * 8,), "y target shape mismatch"


def test_make_subspace_datasets():
X, y, sample_domain = make_shifted_datasets(
n_samples_source=10,
Expand Down

0 comments on commit 8e8fc7b

Please sign in to comment.