From 5884c13d4d094c3e1975b81d49a385ca888b8ad6 Mon Sep 17 00:00:00 2001 From: antoinedemathelin Date: Fri, 16 Feb 2024 11:40:57 +0100 Subject: [PATCH 1/2] remove target labels --- examples/plot_method_comparison.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/plot_method_comparison.py b/examples/plot_method_comparison.py index 8f346eea..347ad7fa 100644 --- a/examples/plot_method_comparison.py +++ b/examples/plot_method_comparison.py @@ -38,6 +38,7 @@ ) from skada.datasets import make_shifted_datasets from skada import source_target_split +from skada.datasets import DomainAwareDataset # Use same random seed for multiple calls to make_datasets to # ensure same distributions @@ -119,6 +120,9 @@ Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain) + dataset = DomainAwareDataset([(Xs, ys, 's'), (Xt, yt, 't')]) + X, y, sample_domain = dataset.pack_train(as_sources=['s'], as_targets=['t']) + x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5 y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5 # just plot the dataset first From 47a635f101350e8362ca99ede80576f502e8ecf2 Mon Sep 17 00:00:00 2001 From: antoinedemathelin Date: Fri, 16 Feb 2024 14:26:13 +0100 Subject: [PATCH 2/2] use return_dataset argument --- examples/plot_method_comparison.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/examples/plot_method_comparison.py b/examples/plot_method_comparison.py index 347ad7fa..9faabf0c 100644 --- a/examples/plot_method_comparison.py +++ b/examples/plot_method_comparison.py @@ -37,8 +37,6 @@ CORAL ) from skada.datasets import make_shifted_datasets -from skada import source_target_split -from skada.datasets import DomainAwareDataset # Use same random seed for multiple calls to make_datasets to # ensure same distributions @@ -84,7 +82,8 @@ shift="covariate_shift", label="binary", noise=0.4, - random_state=RANDOM_SEED + random_state=RANDOM_SEED, + return_dataset=True ), make_shifted_datasets( n_samples_source=20, @@ -92,7 +91,8 @@ shift="target_shift", label="binary", noise=0.4, - random_state=RANDOM_SEED + random_state=RANDOM_SEED, + return_dataset=True ), make_shifted_datasets( n_samples_source=20, @@ -100,7 +100,8 @@ shift="concept_drift", label="binary", noise=0.4, - random_state=RANDOM_SEED + random_state=RANDOM_SEED, + return_dataset=True ), make_shifted_datasets( n_samples_source=20, @@ -108,7 +109,8 @@ shift="subspace", label="binary", noise=0.4, - random_state=RANDOM_SEED + random_state=RANDOM_SEED, + return_dataset=True ), ] @@ -116,12 +118,9 @@ # iterate over datasets for ds_cnt, ds in enumerate(datasets): # preprocess dataset, split into training and test part - X, y, sample_domain = ds - - Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain) - - dataset = DomainAwareDataset([(Xs, ys, 's'), (Xt, yt, 't')]) - X, y, sample_domain = dataset.pack_train(as_sources=['s'], as_targets=['t']) + X, y, sample_domain = ds.pack_train(as_sources=['s'], as_targets=['t']) + Xs, ys = ds.get_domain("s") + Xt, yt = ds.get_domain("t") x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5 y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5