[MRG] Regression label for 2d classification data generation (#69)
* added regression labels
* added a new test for the regression label
* added an example of use of the regression label, and bug fixes
  the main bug that was fixed was due to the y.astype(int) in the return, which cast the float values to integers when using regression labels
* changed the plot to have only one colorbar
* solved an issue with test_make_shifted_datasets_regression
  the test was using a method that has been changed, thus raising errors
* removed two tests that did not make sense
  the two removed tests checked that the y-values were between 0 and 1, which is not necessarily the case in regression
* changed the size of the colorbar
  the colorbar was resized so that the plots next to it look better
* use label instead of binary in _generate_data_2d_classif_subspace
  _generate_data_2d_classif_subspace now uses the label given as a parameter instead of always using "binary"; additionally, the example for the regression label uses the subspace shift
* made multiclass usable for subspace shift
  the main issue was that the generated y values were not of the correct size (not the same as the X values)
* corrected a typo
* added a new test
  this test should cover the change to generate_data_2d_classif_subspace when using the 'multiclass' or 'regression' label
* updated a test
  with subset shift, the values are half as large for the default case
* updated the test
  this was needed with the previous changes
* Update test_samples_generator.py
* Update test_samples_generator.py
* Update test_samples_generator.py
* made the code follow linter's standards
* correction of some mistakes
* changes for flake8

---------

Co-authored-by: Rouxben <[email protected]>
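Review note: the `y.astype(int)` problem mentioned in the commit message can be reproduced in isolation. The sketch below uses only NumPy. The helper name `cast_labels` and the sample values are hypothetical and are not taken from the skada code; it only illustrates why an unconditional integer cast destroys regression targets.

```python
import numpy as np


def cast_labels(y, label="binary"):
    """Hypothetical helper: cast labels to int only for classification."""
    y = np.asarray(y)
    if label == "regression":
        # Regression targets are continuous, so keep them as floats.
        return y.astype(float)
    # Class labels are discrete, so an integer cast is safe here.
    return y.astype(int)


y = np.array([0.25, 0.9, 1.75, -0.6])

# Unconditional integer cast: the fractional part is dropped (truncation
# toward zero), so distinct regression targets collapse onto a few integers.
print(y.astype(int))

# Label-aware cast: regression targets are left unchanged.
print(cast_labels(y, label="regression"))
```

Branching on the label type keeps the integer cast for the binary and multiclass cases, where it is harmless, and skips it only for regression.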
1 parent 97897dc, commit 8e8fc7b
Showing 3 changed files with 175 additions and 20 deletions.
@@ -0,0 +1,87 @@
"""
Plot dataset source domain and shifted target domain
====================================================

This illustrates the :func:`~skada.datasets.make_shifted_datasets`
dataset generator. Each method consists of generating source data
and shifted target data. We illustrate here:
covariate shift, target shift, concept drift, and subspace shift.
See detailed description of each shift in [1]_.

.. [1] Moreno-Torres, J. G., Raeder, T., Alaiz-Rodriguez,
       R., Chawla, N. V., and Herrera, F. (2012).
       A unifying view on dataset shift in classification.
       Pattern recognition, 45(1):521-530.
"""
# %% Imports

import matplotlib.pyplot as plt

from skada.datasets import make_shifted_datasets
from skada import source_target_split


# %% Helper function

def plot_shifted_dataset(shift, random_state=42):
    """Plot source and shifted target data for a given type of shift.

    The possible shifts are 'covariate_shift', 'target_shift',
    'concept_drift' or 'subspace'.

    We use the same random seed across calls so that every plot
    is drawn from the same distributions.
    """
    # Generate source and target samples with continuous (regression) labels.
    X, y, sample_domain = make_shifted_datasets(
        n_samples_source=20,
        n_samples_target=20,
        shift=shift,
        noise=0.3,
        label="regression",
        random_state=random_state,
    )
    X_source, X_target, y_source, y_target = source_target_split(
        X, y, sample_domain=sample_domain)

    fig, (ax1, ax2) = plt.subplots(1, 2, sharex="row", sharey="row", figsize=(8, 4))
    fig.suptitle(shift.replace("_", " ").title(), fontsize=14)
    plt.subplots_adjust(bottom=0.15)
    # Left panel: source domain, colored by the scaled regression label.
    ax1.scatter(
        X_source[:, 0],
        X_source[:, 1],
        c=y_source*10,
        vmax=1,
        alpha=0.5,
    )
    ax1.set_title("Source data")
    ax1.set_xlabel("Feature 1")
    ax1.set_ylabel("Feature 2")

    # Right panel: target domain; keep the handle to build the colorbar.
    s = ax2.scatter(
        X_target[:, 0],
        X_target[:, 1],
        c=y_target*10,
        vmax=1,
        alpha=0.5,
    )
    ax2.set_title("Target data")
    ax2.set_xlabel("Feature 1")
    ax2.set_ylabel("Feature 2")

    # Reserve space on the right of the figure for a single colorbar.
    fig.subplots_adjust(right=0.8)
    cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.80])
    cb = fig.colorbar(s, cax=cbar_ax)
    cb.set_label("y-value*10")

    plt.show()


# %% Visualize shifted datasets

for shift in [
    "covariate_shift",
    "target_shift",
    "concept_drift",
    "subspace"
]:
    plot_shifted_dataset(shift)
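Review note: the example relies on `source_target_split` to separate the two domains from the `sample_domain` array. As a rough illustration of the idea, and not skada's actual implementation, the sketch below assumes that source samples carry positive domain labels and target samples negative ones. The function name `split_by_domain_sign` and the toy arrays are hypothetical.

```python
import numpy as np


def split_by_domain_sign(*arrays, sample_domain):
    """Return a (source, target) pair for each array, in the given order.

    Assumption: labels > 0 mark source samples, labels < 0 mark target ones.
    """
    sample_domain = np.asarray(sample_domain)
    source_mask = sample_domain > 0
    target_mask = sample_domain < 0
    out = []
    for arr in arrays:
        arr = np.asarray(arr)
        # Boolean masks select the rows belonging to each domain.
        out.extend([arr[source_mask], arr[target_mask]])
    return tuple(out)


# Toy data: four source samples followed by two target samples.
X = np.arange(12, dtype=float).reshape(6, 2)
y = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
sample_domain = np.array([1, 1, 1, 1, -2, -2])

X_source, X_target, y_source, y_target = split_by_domain_sign(
    X, y, sample_domain=sample_domain
)
print(X_source.shape, X_target.shape)
```

The return order mirrors the unpacking used in `plot_shifted_dataset`: source then target for the first array, then source and target for the second.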