From bf113302940a66672b1e8b233171b5fc95d3e544 Mon Sep 17 00:00:00 2001 From: Oleksii Kachaiev Date: Mon, 18 Dec 2023 15:35:51 +0100 Subject: [PATCH] Add code cells markup for dataset examples to make them interactive (#39) * Add cells markup for dataset examples to make them interactive * Unify random seed printing * Fix flake --- .../datasets/plot_dataset_from_moons_distribution.py | 8 +++++--- examples/datasets/plot_shifted_dataset.py | 5 +++++ examples/datasets/plot_variable_frequency_dataset.py | 10 +++++++++- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/examples/datasets/plot_dataset_from_moons_distribution.py b/examples/datasets/plot_dataset_from_moons_distribution.py index e86802da..ea1f64f4 100644 --- a/examples/datasets/plot_dataset_from_moons_distribution.py +++ b/examples/datasets/plot_dataset_from_moons_distribution.py @@ -6,7 +6,7 @@ dataset generator. Each method consists of generating source data and shifted target data. """ - +# %% Imports import numpy as np import matplotlib.pyplot as plt @@ -17,6 +17,7 @@ # ensure same distributions RANDOM_SEED = np.random.randint(2**10) +# %% Generate and visualize single-source single-target dataset X, y, sample_domain = make_dataset_from_moons_distribution( pos_source=0.1, @@ -63,6 +64,8 @@ plt.show() +# %% Generate and visualize multi-source multi-target dataset + X, y, sample_domain = make_dataset_from_moons_distribution( pos_source=[0.1, 0.3, 0.5], pos_target=[0.4, 0.9], @@ -78,7 +81,6 @@ fig, (ax1, ax2) = plt.subplots(1, 2, sharex="row", sharey="row", figsize=(8, 4)) fig.suptitle('Multi-source and Multi-target', fontsize=14) plt.subplots_adjust(bottom=0.15) -# for i in sample_domain and positive for i in np.unique(domain_source): ax1.scatter( @@ -114,4 +116,4 @@ plt.show() -print("The data was generated from (random_state=%d):" % RANDOM_SEED) +print(f"The data was generated from (random_state={RANDOM_SEED})") diff --git a/examples/datasets/plot_shifted_dataset.py
b/examples/datasets/plot_shifted_dataset.py index 80e10fc3..6db96ab9 100644 --- a/examples/datasets/plot_shifted_dataset.py +++ b/examples/datasets/plot_shifted_dataset.py @@ -13,6 +13,7 @@ A unifying view on dataset shift in classification. Pattern recognition, 45(1):521-530. """ +# %% Imports import matplotlib.pyplot as plt @@ -20,6 +21,8 @@ from skada import source_target_split +# %% Helper function + def plot_shifted_dataset(shift, random_state=42): """Plot source and shifted target data for a given type of shift. @@ -78,6 +81,8 @@ def plot_shifted_dataset(shift, random_state=42): plt.show() +# %% Visualize shifted datasets + for shift in [ "covariate_shift", "target_shift", diff --git a/examples/datasets/plot_variable_frequency_dataset.py b/examples/datasets/plot_variable_frequency_dataset.py index d0d0c7ee..1b03a17c 100644 --- a/examples/datasets/plot_variable_frequency_dataset.py +++ b/examples/datasets/plot_variable_frequency_dataset.py @@ -6,6 +6,8 @@ dataset generator. Each method consists of generating source data and shifted target data. 
""" +# %% Imports + import numpy as np import matplotlib.pyplot as plt @@ -16,6 +18,8 @@ # ensure same distributions RANDOM_SEED = np.random.randint(2**10) +# %% Generate the dataset + X, y, sample_domain = make_variable_frequency_dataset( n_samples_source=1, n_samples_target=1, @@ -31,6 +35,8 @@ X_source, y_source, X_target, y_target = source_target_split(X, y, sample_domain) +# %% Visualize the signal + fig, ax = plt.subplots(3, 2, sharex="all", sharey="all", figsize=(8, 4)) plt.subplots_adjust(bottom=0.15) fig.suptitle('Signal visualisation') @@ -67,6 +73,8 @@ ax[0, 0].legend() plt.show() +# %% Visualize PSD shift + fig, ax = plt.subplots(3, 2, sharex="all", sharey="all", figsize=(8, 4)) plt.subplots_adjust(bottom=0.15) fig.suptitle('PSD shift') @@ -107,4 +115,4 @@ ax[2, 1].set_xlabel("Frequency") plt.show() -print("The data was generated from (random_state=%d):" % RANDOM_SEED) +print(f"The data was generated from (random_state={RANDOM_SEED})")