Skip to content

Commit

Permalink
Merge pull request #8 from kkovary/fix-int32-out-of-bounds
Browse files Browse the repository at this point in the history
fix int32 out of bounds
  • Loading branch information
daenuprobst authored Feb 19, 2025
2 parents 134ca06 + 2afcc2c commit e394ea2
Show file tree
Hide file tree
Showing 25 changed files with 970 additions and 367 deletions.
32 changes: 14 additions & 18 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Python Package using Conda
name: Python Package using uv

on: [push]

Expand All @@ -9,26 +9,22 @@ jobs:
max-parallel: 5

steps:
- uses: actions/checkout@v2
- name: Set up Python 3.7
uses: actions/setup-python@v2
- uses: actions/checkout@v4
- name: Install uv
uses: astral-sh/setup-uv@v5
with:
python-version: 3.7
- name: Add conda to system path
run: |
# $CONDA is an environment variable pointing to the root of the miniconda directory
echo $CONDA/bin >> $GITHUB_PATH
enable-cache: true
cache-dependency-glob: "uv.lock"
- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Install dependencies
run: |
conda env update --file environment.yml --name base
- name: Lint with flake8
uv sync
- name: Run pre-commit
run: |
conda install flake8
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
uv run pre-commit run --all-files
- name: Test with tox
run: |
pip install -e .[testing]
tox
uv run tox
17 changes: 17 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.3
hooks:
- id: ruff
args: [--fix]
- id: ruff-format

- repo: local
hooks:
- id: pytest
name: pytest
entry: uv run pytest
language: system
types: [python]
pass_filenames: false
always_run: true
9 changes: 0 additions & 9 deletions environment.yml

This file was deleted.

4 changes: 2 additions & 2 deletions notebooks/01_fingerprinting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"import numpy as np\n",
"from matplotlib import pyplot as plt\n",
"from sklearn.decomposition import PCA\n",
"from drfp import DrfpEncoder\n"
"from drfp import DrfpEncoder"
]
},
{
Expand Down Expand Up @@ -288,7 +288,7 @@
"pca = PCA(n_components=2)\n",
"X = pca.fit(fps).transform(fps)\n",
"\n",
"plt.scatter(X[:,0], X[:,1], alpha=0.8)\n",
"plt.scatter(X[:, 0], X[:, 1], alpha=0.8)\n",
"plt.title(\"PCA of 100 drfp-encoded reactions\")"
]
},
Expand Down
4 changes: 2 additions & 2 deletions notebooks/02_model_explainability.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@
}
],
"source": [
"shap.force_plot(explainer.expected_value, shap_values[0,:], matplotlib=True)"
"shap.force_plot(explainer.expected_value, shap_values[0, :], matplotlib=True)"
]
},
{
Expand Down Expand Up @@ -321,7 +321,7 @@
}
],
"source": [
"shap.force_plot(explainer.expected_value, shap_values[42,:], matplotlib=True)"
"shap.force_plot(explainer.expected_value, shap_values[42, :], matplotlib=True)"
]
},
{
Expand Down
8 changes: 3 additions & 5 deletions notebooks/03_more_model_explainability.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@
}
],
"source": [
"\n",
"%pip install theia-pypi xgboost matplotlib faerun-notebook --upgrade\n",
"import pickle\n",
"from pathlib import Path\n",
Expand Down Expand Up @@ -321,6 +320,7 @@
"!pip uninstall ipywidgets -y\n",
"!pip install ipywidgets==7.7.1\n",
"import ipywidgets\n",
"\n",
"ipywidgets.version_info"
]
},
Expand Down Expand Up @@ -375,11 +375,9 @@
"mapping = split[\"test\"][\"mapping\"]\n",
"dataset = InferenceReactionDataset([rxn])\n",
"\n",
"expl = explain_regression(\n",
" dataset, explainer, mapping\n",
")\n",
"expl = explain_regression(dataset, explainer, mapping)\n",
"\n",
"w = { \"reactants\": expl.reactant_weights, \"products\": expl.product_weights}\n",
"w = {\"reactants\": expl.reactant_weights, \"products\": expl.product_weights}\n",
"\n",
"SmilesDrawer(value=[(\"Example\", rxn)], weights=[w], output=\"img\", theme=\"solarized\")"
]
Expand Down
95 changes: 39 additions & 56 deletions notebooks/0a_figures.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,7 @@
"\n",
"pd.options.mode.chained_assignment = None\n",
"\n",
"modern_cmap = LinearSegmentedColormap.from_list(\n",
" \"modern_cmap\", \n",
" [\"#ffffff\", \"#003f5c\"], \n",
" N=256\n",
")\n",
"modern_cmap = LinearSegmentedColormap.from_list(\"modern_cmap\", [\"#ffffff\", \"#003f5c\"], N=256)\n",
"\n",
"schneider_class_names = [\n",
"    \"Reductive amination\",\n",
Expand Down Expand Up @@ -124,9 +120,7 @@
}
],
"source": [
"df = pd.read_csv(\"yield_prediction_results.csv\", names=[\n",
" \"data_set\", \"split\", \"filename\", \"ground_truth\", \"prediction\"\n",
"])\n",
"df = pd.read_csv(\"yield_prediction_results.csv\", names=[\"data_set\", \"split\", \"filename\", \"ground_truth\", \"prediction\"])\n",
"\n",
"df[\"error\"] = df.prediction - df.ground_truth\n",
"\n",
Expand Down Expand Up @@ -181,6 +175,7 @@
" fontfamily=font_family,\n",
" )\n",
"\n",
"\n",
"def calc_r2(df, verbose=True):\n",
" result = []\n",
" result_raw = []\n",
Expand All @@ -196,34 +191,31 @@
" r2 = r2_score(df_tmp.ground_truth, df_tmp.prediction)\n",
" result_raw.append({\"r2\": r2, \"split\": split})\n",
" r2s.append(r2)\n",
" \n",
"\n",
" if verbose:\n",
" print(f\"r2 mean={round(sum(r2s) / len(r2s), 5)}, r2 std={round(stdev(r2s), 5)}\")\n",
" result.append((sum(r2s) / len(r2s), stdev(r2s)))\n",
"\n",
" return (result, pd.DataFrame(result_raw))\n",
"\n",
"\n",
"def scatter(df, ax, title):\n",
" sns.kdeplot(\n",
" data=df, \n",
" x=\"ground_truth\", y=\"prediction\", clip=((0, 100), (None, None)),\n",
" color=\"#003f5c\", levels=6, zorder=2, ax=ax\n",
" data=df, x=\"ground_truth\", y=\"prediction\", clip=((0, 100), (None, None)), color=\"#003f5c\", levels=6, zorder=2, ax=ax\n",
" )\n",
"\n",
" ax.plot(\n",
" [0, 100], [0, 100], linewidth=2, \n",
" color=\"#bc5090\", linestyle=\"dashed\",\n",
" [0, 100],\n",
" [0, 100],\n",
" linewidth=2,\n",
" color=\"#bc5090\",\n",
" linestyle=\"dashed\",\n",
" zorder=1,\n",
" )\n",
"\n",
" sns.scatterplot(\n",
" data=df, \n",
" x=\"ground_truth\", y=\"prediction\",\n",
" color=\"#cccccc\", linewidth=0,\n",
" alpha=0.125, zorder=0, ax=ax\n",
" )\n",
" sns.scatterplot(data=df, x=\"ground_truth\", y=\"prediction\", color=\"#cccccc\", linewidth=0, alpha=0.125, zorder=0, ax=ax)\n",
"\n",
" ax.set(xlabel=\"Ground Truth\", ylabel='Prediction')\n",
" ax.set(xlabel=\"Ground Truth\", ylabel=\"Prediction\")\n",
" ax.set_title(title)"
]
},
Expand Down Expand Up @@ -329,36 +321,28 @@
"splits = [98, 197, 395, 791, 1186, 1977, 2766]\n",
"\n",
"for i in range(7):\n",
" scatter(\n",
" df_buchwald_hartwig_cv[df_buchwald_hartwig_cv.split == splits[i]],\n",
" axs.flat[i], titles[i]\n",
" )\n",
" scatter(df_buchwald_hartwig_cv[df_buchwald_hartwig_cv.split == splits[i]], axs.flat[i], titles[i])\n",
"\n",
"_, df_results = calc_r2(df_buchwald_hartwig_cv, verbose=False)\n",
"\n",
"sns.stripplot(\n",
" x=\"split\", y=\"r2\", data=df_results, linewidth=1, ax=axs.flat[7],\n",
" palette=[\"#003f5c\", \"#374c80\", \"#7a5195\", \"#bc5090\", \"#ef5675\", \"#ff764a\", \"#ffa600\"]\n",
" x=\"split\",\n",
" y=\"r2\",\n",
" data=df_results,\n",
" linewidth=1,\n",
" ax=axs.flat[7],\n",
" palette=[\"#003f5c\", \"#374c80\", \"#7a5195\", \"#bc5090\", \"#ef5675\", \"#ff764a\", \"#ffa600\"],\n",
")\n",
"axs.flat[7].set_xticklabels([\"a\", \"b\", \"c\", \"d\", \"e\", \"f\", \"g\"])\n",
"axs.flat[7].set(xlabel=\"Split\", ylabel=\"Accuracy\")\n",
"\n",
"titles = [\n",
" \"Out-of-sample Split 1\", \"Out-of-sample Split 2\", \n",
" \"Out-of-sample Split 3\", \"Out-of-sample Split 4\"\n",
"]\n",
"titles = [\"Out-of-sample Split 1\", \"Out-of-sample Split 2\", \"Out-of-sample Split 3\", \"Out-of-sample Split 4\"]\n",
"\n",
"splits = [\n",
" \"Test1-2048-3-true.pkl\", \"Test2-2048-3-true.pkl\", \n",
" \"Test3-2048-3-true.pkl\", \"Test4-2048-3-true.pkl\"\n",
"]\n",
"splits = [\"Test1-2048-3-true.pkl\", \"Test2-2048-3-true.pkl\", \"Test3-2048-3-true.pkl\", \"Test4-2048-3-true.pkl\"]\n",
"\n",
"j = 0\n",
"for i in range(8, 12):\n",
" scatter(\n",
" df_buchwald_hartwig_tests[df_buchwald_hartwig_tests.split == splits[j]],\n",
" axs.flat[i], titles[j]\n",
" )\n",
" scatter(df_buchwald_hartwig_tests[df_buchwald_hartwig_tests.split == splits[j]], axs.flat[i], titles[j])\n",
" j += 1\n",
"\n",
"index_subplots(axs.flat, font_size=14, y=1.17)\n",
Expand Down Expand Up @@ -401,7 +385,7 @@
"\n",
"plt_cm = []\n",
"for i in cm.classes:\n",
" row=[]\n",
" row = []\n",
" for j in cm.classes:\n",
" row.append(cm.table[i][j])\n",
" plt_cm.append(row)\n",
Expand All @@ -414,9 +398,7 @@
"\n",
"\n",
"sns.heatmap(\n",
" plt_cm, cmap=\"RdPu\", linewidths=.1, linecolor=\"#eeeeee\", square=True, \n",
" cbar_kws={\"shrink\": 0.5}, norm=LogNorm(),\n",
" ax=ax\n",
" plt_cm, cmap=\"RdPu\", linewidths=0.1, linecolor=\"#eeeeee\", square=True, cbar_kws={\"shrink\": 0.5}, norm=LogNorm(), ax=ax\n",
")\n",
"\n",
"cax = plt.gcf().axes[-1]\n",
Expand Down Expand Up @@ -491,8 +473,17 @@
"y.extend(y_train)\n",
"y.extend(y_test)\n",
"\n",
"labels = {\"1\": \"Heteroatom alkylation and arylation\", \"2\": \"Acylation and related processes\", \"3\": \"C-C bond formation\", \"5\": \"Protections\", \"6\": \"Deprotections\",\n",
" \"7\": \"Reductions\", \"8\": \"Oxidations\", \"9\": \"Functional group interconversion (FGI)\", \"10\": \"Functional group addition (FGA)\"}\n",
"labels = {\n",
" \"1\": \"Heteroatom alkylation and arylation\",\n",
" \"2\": \"Acylation and related processes\",\n",
" \"3\": \"C-C bond formation\",\n",
" \"5\": \"Protections\",\n",
" \"6\": \"Deprotections\",\n",
" \"7\": \"Reductions\",\n",
" \"8\": \"Oxidations\",\n",
" \"9\": \"Functional group interconversion (FGI)\",\n",
" \"10\": \"Functional group addition (FGA)\",\n",
"}\n",
"\n",
"y_values = [labels[ytem.split(\".\")[0]] for ytem in y]\n",
"\n",
Expand Down Expand Up @@ -535,24 +526,16 @@
" \"#595959\",\n",
" \"#5f9ed1\",\n",
" \"#c85300\",\n",
" #\"#898989\",\n",
" # \"#898989\",\n",
" \"#a2c8ec\",\n",
" \"#ffbc79\",\n",
" \"#cfcfcf\"\n",
" \"#cfcfcf\",\n",
"]\n",
"\n",
"df_tmap = pd.DataFrame({\"x\": x, \"y\": y, \"c\": y_values})\n",
"sns.scatterplot(x=\"x\", y=\"y\", hue=\"c\", data=df_tmap, s=5.0, palette=palette, ax=ax, zorder=2)\n",
"\n",
"legend = ax.legend(\n",
" loc=\"center left\", \n",
" bbox_to_anchor=(1, 0.5),\n",
" fancybox=False, \n",
" shadow=False, \n",
" frameon=False,\n",
" ncol=1,\n",
" fontsize=7\n",
")\n",
"legend = ax.legend(loc=\"center left\", bbox_to_anchor=(1, 0.5), fancybox=False, shadow=False, frameon=False, ncol=1, fontsize=7)\n",
"\n",
"for handle in legend.legendHandles:\n",
" handle.set_sizes([12.0])\n",
Expand Down
Loading

0 comments on commit e394ea2

Please sign in to comment.