Skip to content

Commit

Permalink
Add code cells markup for dataset examples to make them interactive (#39
Browse files Browse the repository at this point in the history
)

* Add cells markup for dataset examples to make the interactive

* Unify random seed printing

* Fix flake
  • Loading branch information
kachayev authored Dec 18, 2023
1 parent fe42db9 commit bf11330
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 4 deletions.
8 changes: 5 additions & 3 deletions examples/datasets/plot_dataset_from_moons_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
dataset generator. Each method consists of generating source data
and shifted target data.
"""

# %% Imports
import numpy as np
import matplotlib.pyplot as plt

Expand All @@ -17,6 +17,7 @@
# ensure same distributions
RANDOM_SEED = np.random.randint(2**10)

# %% Generate and visualize single-source single-target dataset

X, y, sample_domain = make_dataset_from_moons_distribution(
pos_source=0.1,
Expand Down Expand Up @@ -63,6 +64,8 @@

plt.show()

# %% Generate and visualize multi-source multi-target dataset

X, y, sample_domain = make_dataset_from_moons_distribution(
pos_source=[0.1, 0.3, 0.5],
pos_target=[0.4, 0.9],
Expand All @@ -78,7 +81,6 @@
fig, (ax1, ax2) = plt.subplots(1, 2, sharex="row", sharey="row", figsize=(8, 4))
fig.suptitle('Multi-source and Multi-target', fontsize=14)
plt.subplots_adjust(bottom=0.15)
# for i in sample_domain and positive

for i in np.unique(domain_source):
ax1.scatter(
Expand Down Expand Up @@ -114,4 +116,4 @@

plt.show()

print("The data was generated from (random_state=%d):" % RANDOM_SEED)
print(f"The data was generated from (random_state={RANDOM_SEED})")
5 changes: 5 additions & 0 deletions examples/datasets/plot_shifted_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@
A unifying view on dataset shift in classification.
Pattern recognition, 45(1):521-530.
"""
# %% Imports

import matplotlib.pyplot as plt

from skada.datasets import make_shifted_datasets
from skada import source_target_split


# %% Helper function

def plot_shifted_dataset(shift, random_state=42):
"""Plot source and shifted target data for a given type of shift.
Expand Down Expand Up @@ -78,6 +81,8 @@ def plot_shifted_dataset(shift, random_state=42):
plt.show()


# %% Visualize shifted datasets

for shift in [
"covariate_shift",
"target_shift",
Expand Down
10 changes: 9 additions & 1 deletion examples/datasets/plot_variable_frequency_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
dataset generator. Each method consists of generating source data
and shifted target data.
"""
# %% Imports

import numpy as np
import matplotlib.pyplot as plt

Expand All @@ -16,6 +18,8 @@
# ensure same distributions
RANDOM_SEED = np.random.randint(2**10)

# %% Generate the dataset

X, y, sample_domain = make_variable_frequency_dataset(
n_samples_source=1,
n_samples_target=1,
Expand All @@ -31,6 +35,8 @@

X_source, y_source, X_target, y_target = source_target_split(X, y, sample_domain)

# %% Visualize the signal

fig, ax = plt.subplots(3, 2, sharex="all", sharey="all", figsize=(8, 4))
plt.subplots_adjust(bottom=0.15)
fig.suptitle('Signal visualisation')
Expand Down Expand Up @@ -67,6 +73,8 @@
ax[0, 0].legend()
plt.show()

# %% Visualize PSD shift

fig, ax = plt.subplots(3, 2, sharex="all", sharey="all", figsize=(8, 4))
plt.subplots_adjust(bottom=0.15)
fig.suptitle('PSD shift')
Expand Down Expand Up @@ -107,4 +115,4 @@
ax[2, 1].set_xlabel("Frequency")
plt.show()

print("The data was generated from (random_state=%d):" % RANDOM_SEED)
print(f"The data was generated from (random_state={RANDOM_SEED})")

0 comments on commit bf11330

Please sign in to comment.