diff --git a/.gitignore b/.gitignore index dafb4a21b..eb302a491 100644 --- a/.gitignore +++ b/.gitignore @@ -90,6 +90,7 @@ target/ # IDEs (VSCode, etc.) .env *.code-workspace +.vscode/ # Class diagram *.pyns diff --git a/.vscode/.ropeproject/config.py b/.vscode/.ropeproject/config.py deleted file mode 100644 index c339bc796..000000000 --- a/.vscode/.ropeproject/config.py +++ /dev/null @@ -1,123 +0,0 @@ -# The default ``config.py`` -# flake8: noqa - - -def set_prefs(prefs): - """This function is called before opening the project""" - - # Specify which files and folders to ignore in the project. - # Changes to ignored resources are not added to the history and - # VCSs. Also they are not returned in `Project.get_files()`. - # Note that ``?`` and ``*`` match all characters but slashes. - # '*.pyc': matches 'test.pyc' and 'pkg/test.pyc' - # 'mod*.pyc': matches 'test/mod1.pyc' but not 'mod/1.pyc' - # '.svn': matches 'pkg/.svn' and all of its children - # 'build/*.o': matches 'build/lib.o' but not 'build/sub/lib.o' - # 'build//*.o': matches 'build/lib.o' and 'build/sub/lib.o' - prefs["ignored_resources"] = [ - "*.pyc", - "*~", - ".ropeproject", - ".hg", - ".svn", - "_svn", - ".git", - ".tox", - ] - - # Specifies which files should be considered python files. It is - # useful when you have scripts inside your project. Only files - # ending with ``.py`` are considered to be python files by - # default. - # prefs['python_files'] = ['*.py'] - - # Custom source folders: By default rope searches the project - # for finding source folders (folders that should be searched - # for finding modules). You can add paths to that list. Note - # that rope guesses project source folders correctly most of the - # time; use this if you have any problems. - # The folders should be relative to project root and use '/' for - # separating folders regardless of the platform rope is running on. - # 'src/my_source_folder' for instance. - # prefs.add('source_folders', 'src') - - # You can extend python path for looking up modules - # prefs.add('python_path', '~/python/') - - # Should rope save object information or not. - prefs["save_objectdb"] = True - prefs["compress_objectdb"] = False - - # If `True`, rope analyzes each module when it is being saved. - prefs["automatic_soa"] = True - # The depth of calls to follow in static object analysis - prefs["soa_followed_calls"] = 0 - - # If `False` when running modules or unit tests "dynamic object - # analysis" is turned off. This makes them much faster. - prefs["perform_doa"] = True - - # Rope can check the validity of its object DB when running. - prefs["validate_objectdb"] = True - - # How many undos to hold? - prefs["max_history_items"] = 32 - - # Shows whether to save history across sessions. - prefs["save_history"] = True - prefs["compress_history"] = False - - # Set the number spaces used for indenting. According to - # :PEP:`8`, it is best to use 4 spaces. Since most of rope's - # unit-tests use 4 spaces it is more reliable, too. - prefs["indent_size"] = 4 - - # Builtin and c-extension modules that are allowed to be imported - # and inspected by rope. - prefs["extension_modules"] = [] - - # Add all standard c-extensions to extension_modules list. - prefs["import_dynload_stdmods"] = True - - # If `True` modules with syntax errors are considered to be empty. - # The default value is `False`; When `False` syntax errors raise - # `rope.base.exceptions.ModuleSyntaxError` exception. - prefs["ignore_syntax_errors"] = False - - # If `True`, rope ignores unresolvable imports. Otherwise, they - # appear in the importing namespace. - prefs["ignore_bad_imports"] = False - - # If `True`, rope will insert new module imports as - # `from import ` by default. - prefs["prefer_module_from_imports"] = False - - # If `True`, rope will transform a comma list of imports into - # multiple separate import statements when organizing - # imports. - prefs["split_imports"] = False - - # If `True`, rope will remove all top-level import statements and - # reinsert them at the top of the module when making changes. - prefs["pull_imports_to_top"] = True - - # If `True`, rope will sort imports alphabetically by module name instead - # of alphabetically by import statement, with from imports after normal - # imports. - prefs["sort_imports_alphabetically"] = False - - # Location of implementation of - # rope.base.oi.type_hinting.interfaces.ITypeHintingFactory In general - # case, you don't have to change this value, unless you're an rope expert. - # Change this value to inject you own implementations of interfaces - # listed in module rope.base.oi.type_hinting.providers.interfaces - # For example, you can add you own providers for Django Models, or disable - # the search type-hinting in a class hierarchy, etc. - prefs["type_hinting_factory"] = ( - "rope.base.oi.type_hinting.factory.default_type_hinting_factory" - ) - - -def project_opened(project): - """This function is called after opening the project""" - # Do whatever you like here! diff --git a/.vscode/autodocstring.template b/.vscode/autodocstring.template deleted file mode 100644 index cf934b72b..000000000 --- a/.vscode/autodocstring.template +++ /dev/null @@ -1,28 +0,0 @@ -{{! Google Docstring Template without types }} -{{summaryPlaceholder}} - -{{extendedSummaryPlaceholder}} - -{{#parametersExist}} -Args: -{{#args}} - {{var}}: {{descriptionPlaceholder}} -{{/args}} -{{#kwargs}} - {{var}}: {{descriptionPlaceholder}} -{{/kwargs}} -{{/parametersExist}} - -{{#returnsExist}} -Returns: -{{#returns}} - {{typePlaceholder}}: {{descriptionPlaceholder}} -{{/returns}} -{{/returnsExist}} - -{{#yieldsExist}} -Yields: -{{#yields}} - {{typePlaceholder}}: {{descriptionPlaceholder}} -{{/yields}} -{{/yieldsExist}} diff --git a/.vscode/extensions.json b/.vscode/extensions.json deleted file mode 100644 index 8faab311b..000000000 --- a/.vscode/extensions.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - // See https://go.microsoft.com/fwlink/?LinkId=827846 to learn about workspace recommendations. - // Extension identifier format: ${publisher}.${name}. Example: vscode.csharp - // List of extensions which should be recommended for users of this workspace. - "recommendations": [ - "stkb.rewrap", - "ms-pyright.pyright", - "eamodio.gitlens", - "ms-vsliveshare.vsliveshare", - "ms-vsliveshare.vsliveshare-audio", - "karigari.chat", - "njpwerner.autodocstring", - "oijaz.unicode-latex", - "donjayamanne.githistory" - ], - // List of extensions recommended by VS Code that should not be recommended for users of this workspace. - "unwantedRecommendations": [] -} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 484b73acb..000000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,76 +0,0 @@ -{ - "python.testing.unittestEnabled": false, - "python.testing.nosetestsEnabled": false, - "python.testing.pytestEnabled": true, - "python.testing.pytestArgs": [ - "tests" - ], - // - // Editor settings for python - // - "[python]": { - "editor.defaultFormatter": "ms-python.black-formatter", - "editor.formatOnSave": true, - "editor.codeActionsOnSave": { - "source.sortImports": "explicit" - }, // OK as long as isort 3rd parties managed in setup.cfg?, see below - "editor.wordWrapColumn": 88, - "editor.renderWhitespace": "boundary", - "editor.wordWrap": "wordWrapColumn", - "editor.fontLigatures": true, - "editor.rulers": [ - 88 - ], - "rewrap.wholeComment": false, - "rewrap.doubleSentenceSpacing": true, - // - // Please add the rewrap vscode plugin (stkb.rewrap) - // Still pending https://github.com/stkb/Rewrap/issues/88 for full python docstring support - "rewrap.autoWrap.enabled": true, - "rewrap.wrappingColumn": 88, - }, - "workbench.colorCustomizations": { - "editorRuler.foreground": "#444444" - }, - // - // Formatting - // https://code.visualstudio.com/docs/python/editing#_formatting - // - "python.formatting.provider": "none", - "python.formatting.blackArgs": [ - "--line-length=88" - ], - // - // Sort imports - // https://github.com/microsoft/vscode-python/issues/5840#issuecomment-497321419 - // - "python.sortImports.args": [ - "--settings-path=${workspaceFolder}/setup.cfg" - ], - // - // Linting - // https://code.visualstudio.com/docs/python/linting#_specific-linters - // - "python.linting.enabled": true, - "python.linting.lintOnSave": true, - "python.linting.flake8Enabled": true, - "python.linting.args": [ - "--settings-path=${workspaceFolder}/setup.cfg" - ], - // - // Files to exclude - // - "files.exclude": { - "**/__pycache__": true, - "**/.pytest_cache": true, - "**/.ipynb_checkpoints": true, - "**/*.egg-info": true, - }, - // - // Please add the autoDocstring vscode plugin (njpwerner.autodocstring) - // - "autoDocstring.customTemplatePath": ".vscode/autodocstring.template", - // - // Signal that we are using shared workspace settings in .vscode - "window.title": "sbi.vscode:: ${dirty}${activeEditorShort}${separator}${rootName}${separator}${appName}", -} \ No newline at end of file diff --git a/.vscode/snippets.code-snippets b/.vscode/snippets.code-snippets deleted file mode 100644 index ce24abe56..000000000 --- a/.vscode/snippets.code-snippets +++ /dev/null @@ -1,36 +0,0 @@ -{ - "Annotations from future": { - "prefix": "annf", - "body": [ - "from __future__ import annotations", - "" - ], - "description": "Annotations from future", - "scope": "python" - }, - "Short copyright notice": { - "prefix": "copyr", - "body": [ - "# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed", - "# under the Affero General Public License v3, see .", - "" - ], - "description": "Short copyright notice" - } - // Place your sbi workspace snippets here. Each snippet is defined under a snippet name and has a scope, prefix, body and - // description. Add comma separated ids of the languages where the snippet is applicable in the scope field. If scope - // is left empty or omitted, the snippet gets applied to all languages. The prefix is what is - // used to trigger the snippet and the body will be expanded and inserted. Possible variables are: - // $1, $2 for tab stops, $0 for the final cursor position, and ${1:label}, ${2:another} for placeholders. - // Placeholders with the same ids are connected. - // Example: - // "Print to console": { - // "scope": "javascript,typescript", - // "prefix": "log", - // "body": [ - // "console.log('$1');", - // "$2" - // ], - // "description": "Log output to console" - // } -} \ No newline at end of file diff --git a/environment.yml b/environment.yml deleted file mode 100644 index 91f7eca18..000000000 --- a/environment.yml +++ /dev/null @@ -1,36 +0,0 @@ -# -# Create: -# $ conda env create --prefix .sbi_env --file environment.yml -# -# Update: -# $ conda env update --prefix .sbi_env --file environment.yml --prune -# -# Activate: -# $ conda activate ~/path/to/sbi/.sbi_env -# -# Nicer prompt (adds to ~/.condarc): -# $ conda config --set env_prompt '({name}) ' -# -name: sbi_env - -channels: - - conda-forge - - pytorch - -dependencies: - - arviz - - cudatoolkit - - jupyter - - jupyterlab - - matplotlib - - notebook - - pillow - - pip - - pip: - - "pyknos>=0.14.2" - - "pyro-ppl>=1.3.1" - - -e ".[dev]" - - "python >= 3.6.0" - - "pytorch >= 1.8.0" - - scikit-learn - - scipy diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..ca75649f4 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,123 @@ +[build-system] +requires = [ + "setuptools>=65", "wheel" +] +build-backend = "setuptools.build_meta" + +[project] +name = "sbi" +description = "Simulation-based inference." +authors = [ + { name = "sbi-dev", email = "simulation.based.inference@gmail.com"}, +] +classifiers = [ + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "Topic :: Adaptive Technologies", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Mathematics", + """License :: OSI Approved :: GNU Affero General Public License v3 or later + (AGPLv3+)""", + "Natural Language :: English", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Development Status :: 3 - Alpha", +] +requires-python = ">=3.6" +dynamic = ["version"] +readme = "README.md" +keywords = ["Bayesian inference", "simulation-based inference", "PyTorch"] +dependencies = [ + "arviz", + "joblib>=1.0.0", + "matplotlib", + "numpy", + "pillow", + "pyknos>=0.15.1", + "pyro-ppl>=1.3.1", + "scikit-learn", + "scipy", + "tensorboard", + "torch>=1.8.0", + "tqdm", +] + +[project.optional-dependencies] +dev = [ + "autoflake", + "black", + "deepdiff", + "flake8", + "isort", + "jupyter", + "mkdocs", + "mkdocs-material", + "markdown-include", + "mkdocs-redirects", + "mkdocstrings[python] >= 0.18", + "nbconvert", + "pre-commit", + "pytest", + "pyyaml", + "pyright >=1.1.300,<1.1.306", + "torchtestcase", + "twine", +] + +[project.urls] +documentation = "https://sbi-dev.github.io/sbi/" +source = "https://github.com/sbi-dev/sbi" +tracker = "https://github.com/sbi-dev/sbi/issues" + +[tool.black] +line-length = 88 +target-version = ['py37', 'py38', 'py39', 'py310', 'py311'] +include = '\.pyi?$|\.ipynb$' +extend-exclude = ''' +/( + \.git + | \.venv + | \.ipynb_checkpoints +)/ +''' + +[tool.isort] +line_length = 88 +include_trailing_comma = true +use_parentheses = true +skip_glob = [".ipynb_checkpoints", "docs/*"] +known_first_party=["sbi", "tests", "examples", "tutorials"] +known_third_party = ["arviz", "joblib", "matplotlib", "numpy", "pyknos", "pyro", "pytest", "scipy", "six", "sklearn", "tensorboard", "torch", "torchtestcase", "tqdm", "typing_extensions"] +multi_line_output = 3 + +[tool.flake8] +max-line-length = 88 +exclude = [ + "docs", + "build", + "dist", + ".ipynb_checkpoints" +] + +# Pytest configuration +[tool.pytest.ini_options] +minversion = "6.0" +addopts = "-ra -q" +testpaths = [ + "tests", +] +markers = [ + "slow: marks tests as slow (deselect with '-m \"not slow\"')", + "gpu: marks tests that require a gpu (deselect with '-m \"not gpu\"')" +] + +[tool.setuptools.packages.find] +where = ["."] # list of folders that contain the packages (["."] by default) +include = ["sbi*"] # package names should match these glob patterns (["*"] by default) +exclude = ["sbi-logs*"] # exclude packages matching these glob patterns (empty by default) +namespaces = false # to disable scanning PEP 420 namespaces (true by default) + +[tool.setuptools.dynamic] +version = {attr = "sbi.__version__"} diff --git a/sbi/inference/__init__.py b/sbi/inference/__init__.py index dcece8f7c..8dc9f50d5 100644 --- a/sbi/inference/__init__.py +++ b/sbi/inference/__init__.py @@ -13,12 +13,8 @@ from sbi.inference.abc.mcabc import MCABC from sbi.inference.abc.smcabc import SMCABC -from sbi.inference.base import ( # noqa: F401 - NeuralInference, - check_if_proposal_has_default_x, - infer, - simulate_for_sbi, -) +from sbi.inference.base import NeuralInference # noqa: F401 +from sbi.inference.base import check_if_proposal_has_default_x, infer, simulate_for_sbi from sbi.inference.snle.mnle import MNLE from sbi.inference.snle.snle_a import SNLE_A from sbi.inference.snpe.snpe_a import SNPE_A diff --git a/sbi/neural_nets/classifier.py b/sbi/neural_nets/classifier.py index 31bb7a8e8..4000f86ff 100644 --- a/sbi/neural_nets/classifier.py +++ b/sbi/neural_nets/classifier.py @@ -89,7 +89,7 @@ def build_linear_classifier( z_score_y: Optional[str] = "independent", embedding_net_x: nn.Module = nn.Identity(), embedding_net_y: nn.Module = nn.Identity(), - **kwargs + **kwargs, ) -> nn.Module: """Builds linear classifier. diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 8f610a8da..000000000 --- a/setup.cfg +++ /dev/null @@ -1,21 +0,0 @@ -[metadata] -description-file = README.md - -[tool:pytest] -markers = - slow: marks tests as slow (deselect with '-m "not slow"') - gpu: marks tests that require a gpu (deselect with '-m "not gpu"') - -[flake8] -max-line-length = 88 -exclude = docs, build, dist, .ipynb_checkpoints - -[isort] -line_length = 88 -include_trailing_comma=True -force_grid_wrap=0 -use_parentheses=True -skip_glob=.ipynb_checkpoints -known_first_party=sbi,tests -known_third_party=arviz,joblib,matplotlib,numpy,pyknos,pyro,pytest,scipy,setuptools,six,sklearn,tensorboard,torch,torchtestcase,tqdm,typing_extensions -multi_line_output=3 diff --git a/setup.py b/setup.py deleted file mode 100644 index a96541085..000000000 --- a/setup.py +++ /dev/null @@ -1,156 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed -# under the Affero General Public License v3, see . -# -# Note: To use the 'upload' functionality of this file, you must: -# $ pipenv install twine --dev - -import io -import os -import sys -from shutil import rmtree - -from setuptools import Command, find_packages, setup - -# Package meta-data. -NAME = "sbi" -DESCRIPTION = "Simulation-based inference." -KEYWORDS = "bayesian parameter inference system_identification simulator PyTorch" -URL = "https://github.com/sbi-dev/sbi" -EMAIL = "simulation.based.inference@gmail.com" -AUTHOR = """Álvaro Tejero-Cantero, Jakob H. Macke, Jan-Matthis Lückmann, Conor M. - Durkan, Michael Deistler, Jan Bölts""" -REQUIRES_PYTHON = ">=3.6.0" - -REQUIRED = [ - "arviz", - "joblib>=1.0.0", - "matplotlib", - "numpy", - "pillow", - "pyknos>=0.15.1", - "pyro-ppl>=1.3.1", - "scikit-learn", - "scipy", - "tensorboard", - "torch>=1.8.0", - "tqdm", -] - -EXTRAS = { - "dev": [ - "autoflake", - "black", - "deepdiff", - "flake8", - "isort", - "jupyter", - "mkdocs", - "mkdocs-material", - "markdown-include", - "mkdocs-redirects", - "mkdocstrings[python]>=0.18", - "nbconvert", - "pep517", - "pre-commit", - "pytest", - "pyyaml", - "pyright>=1.1.300,<1.1.306", - "torchtestcase", - "twine", - ], -} - -here = os.path.abspath(os.path.dirname(__file__)) - -# Import the README and use it as the long-description. -try: - with io.open(os.path.join(here, "README.md"), encoding="utf-8") as f: - long_description = "\n" + f.read() -except FileNotFoundError: - long_description = DESCRIPTION - -# Load the package's __version__.py module as a dictionary. -about = {} -project_slug = NAME.lower().replace("-", "_").replace(" ", "_") -with open(os.path.join(here, project_slug, "__version__.py")) as f: - exec(f.read(), about) - - -class UploadCommand(Command): - """Support setup.py upload.""" - - description = "Build and publish the package." - user_options = [] - - @staticmethod - def status(s): - """Prints things in bold.""" - print("\033[1m{0}\033[0m".format(s)) - - def initialize_options(self): - pass - - def finalize_options(self): - pass - - def run(self): - try: - self.status("Removing previous builds…") - rmtree(os.path.join(here, "dist")) - except OSError: - pass - - self.status("Building Source and Wheel (universal) distribution…") - os.system("{0} setup.py sdist bdist_wheel --universal".format(sys.executable)) - - self.status("Uploading the package to PyPI via Twine…") - os.system("twine upload dist/*") - - self.status("Pushing git tags…") - os.system("git tag v{0}".format(about["__version__"])) - os.system("git push --tags") - - sys.exit() - - -setup( - name=NAME, - version=about["__version__"], - description=DESCRIPTION, - keywords=KEYWORDS, - long_description=long_description, - long_description_content_type="text/markdown", - author=AUTHOR, - author_email=EMAIL, - python_requires=REQUIRES_PYTHON, - url=URL, - packages=find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*"]), - install_requires=REQUIRED, - extras_require=EXTRAS, - include_package_data=True, - license="AGPLv3", - classifiers=[ - # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers - "Development Status :: 3 - Alpha", - """License :: OSI Approved :: GNU Affero General Public License v3 or later - (AGPLv3+)""", - "Intended Audience :: Developers", - "Intended Audience :: Education", - "Intended Audience :: Science/Research", - "Topic :: Adaptive Technologies", - "Topic :: Scientific/Engineering", - "Topic :: Scientific/Engineering :: Artificial Intelligence", - "Topic :: Scientific/Engineering :: Mathematics", - # TODO: Update python support? - "Programming Language :: Python", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - ], - # $ setup.py publish support. - cmdclass=dict(upload=UploadCommand), -)