diff --git a/.github/envs/environment.yml b/.github/envs/environment.yml new file mode 100644 index 0000000..27e17ae --- /dev/null +++ b/.github/envs/environment.yml @@ -0,0 +1,20 @@ +name: test-environment +channels: + - conda-forge +dependencies: + - dask >=2025 + - pandas + - polars + - pyspark + - pyarrow >=15 + - numpy + - pytest + - pytest-cov + - numba + - awkward + - distributed + - openjdk ==20 + - pip + - pip: + - ray[data] + - git+https://github.com/dask-contrib/dask-awkward diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index 914f094..8868691 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -18,20 +18,24 @@ jobs: fail-fast: false matrix: platform: [ubuntu-latest, macos-latest, windows-latest] - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12"] runs-on: ${{matrix.platform}} steps: - name: Checkout uses: actions/checkout@v3 - - name: setup Python ${{matrix.python-version}} - uses: actions/setup-python@v4 + - name: Setup Conda Environment + uses: conda-incubator/setup-miniconda@v3 with: - python-version: ${{matrix.python-version}} + python-version: ${{ matrix.python-version }} + environment-file: .github/envs/environment.yml + activate-environment: test-environment - name: install + shell: bash -l {0} run: | pip install pip wheel -U - pip install -q --no-cache-dir .[test] + pip install -q --no-cache-dir -e .[test] pip list - name: test + shell: bash -l {0} run: | python -m pytest -v --cov-config=.coveragerc --cov akimbo diff --git a/docs/demo/akimbo-demo.ipynb b/docs/akimbo-demo.ipynb similarity index 99% rename from docs/demo/akimbo-demo.ipynb rename to docs/akimbo-demo.ipynb index 7b992c1..845e889 100644 --- a/docs/demo/akimbo-demo.ipynb +++ b/docs/akimbo-demo.ipynb @@ -1,5 +1,15 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "8b1be0e8", + "metadata": {}, + "source": [ + "# HEP Demo\n", + "\n", + "Here we show a plausible small workflow on a real excerpt of particle data." + ] + }, { "cell_type": "code", "execution_count": 1, diff --git a/docs/api.rst b/docs/api.rst index a2282f8..e48575c 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -1,21 +1,6 @@ akimbo ============== -.. currentmodule:: akimbo - -Top Level Functions -~~~~~~~~~~~~~~~~~~~ - -.. autosummary:: - :toctree: generated/ - - read_parquet - read_json - read_avro - get_parquet_schema - get_json_schema - get_avro_schema - Accessor ~~~~~~~~ @@ -38,6 +23,8 @@ Backends akimbo.dask.DaskAwkwardAccessor akimbo.polars.PolarsAwkwardAccessor akimbo.cudf.CudfAwkwardAccessor + akimbo.ray.RayAccessor + akimbo.spark.SparkAccessor .. autoclass:: akimbo.pandas.PandasAwkwardAccessor @@ -47,6 +34,25 @@ Backends .. autoclass:: akimbo.cudf.CudfAwkwardAccessor +.. autoclass:: akimbo.ray.RayAccessor + +.. autoclass:: akimbo.spark.SparkAccessor + +Top Level Functions +~~~~~~~~~~~~~~~~~~~ +.. currentmodule:: akimbo + + +.. 
autosummary:: + :toctree: generated/ + + read_parquet + read_json + read_avro + get_parquet_schema + get_json_schema + get_avro_schema + Extensions ~~~~~~~~~~
diff --git a/docs/conf.py b/docs/conf.py index 461cf3c..ebdb78c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -24,13 +24,7 @@ ] templates_path = ["_templates"] -exclude_patterns = [ - "_build", - "Thumbs.db", - ".DS_Store", - "**.ipynb_checkpoints", - "**akimbo-demo.ipynb", -] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints"] # -- Options for HTML output -------------------------------------------------
diff --git a/docs/cudf-ak.ipynb b/docs/cudf-ak.ipynb new file mode 120000 index 0000000..8f765e8 --- /dev/null +++ b/docs/cudf-ak.ipynb @@ -0,0 +1 @@ +example/cudf-ak.ipynb \ No newline at end of file
diff --git a/docs/demo/.gitignore b/docs/demo/.gitignore deleted file mode 100644 index 4bed5da..0000000 --- a/docs/demo/.gitignore +++ /dev/null @@ -1 +0,0 @@ -*.parquet
diff --git a/example/cuda_env.yaml b/docs/example/cuda_env.yaml similarity index 100% rename from example/cuda_env.yaml rename to docs/example/cuda_env.yaml
diff --git a/example/cudf-ak.ipynb b/docs/example/cudf-ak.ipynb similarity index 97% rename from example/cudf-ak.ipynb rename to docs/example/cudf-ak.ipynb index f786c4e..8f3e938 100644 --- a/example/cudf-ak.ipynb +++ b/docs/example/cudf-ak.ipynb @@ -1,10 +1,19 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "ee00a3e2", + "metadata": {}, + "source": [ + "# GPU backend" + ] + }, { "cell_type": "markdown", "id": "58d18a3a-45b1-425a-b822-e8be0a6c0bc0", "metadata": {}, "source": [ + "This example depends on data in a file that can be made in the following way.\n", "\n", "```python\n", "import awkward as ak\n", @@ -14,6 +23,11 @@ " [[6, 7]]] * N\n", " arr = ak.Array({\"a\": part})\n", " ak.to_parquet(arr, fn, extensionarray=False)\n", + "```\n", + "\n", + "The file cuda_env.yaml can be used to create a functional environment using conda:\n", + "```bash\n", + "$ conda env create -f docs/example/cuda_env.yaml\n", + "```" ] }, @@ -617,7 +631,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.0" + "version": "3.10.9" } }, "nbformat": 4,
diff --git a/docs/index.rst b/docs/index.rst index c3daebd..1ec24f4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -24,6 +24,8 @@ identical syntax: - dask.dataframe - polars - cuDF +- ray dataset +- pyspark numpy-like API @@ -111,6 +113,13 @@ the ``akimbo`` system, you can apply these methods to ragged/nested dataframes. install.rst quickstart.ipynb +.. toctree:: + :maxdepth: 1 + :caption: Demos + + akimbo-demo.ipynb + cudf-ak.ipynb + .. toctree:: :maxdepth: 1 :caption: API Reference
diff --git a/docs/install.rst b/docs/install.rst index 6c05a4f..50fbd75 100644 --- a/docs/install.rst +++ b/docs/install.rst @@ -5,7 +5,11 @@ Requirements ~~~~~~~~~~~~ To install ``akimbo`` you will need ``awkward`` and -one of the backend libraries: ``pandas``, ``dask`` or ``polars``. +one of the backend libraries: ``pandas``, ``dask``, ``cuDF``, ``ray.data``, +``pyspark`` or ``polars``. Each of these has various installation options; +please see their respective documentation. + +``akimbo`` depends on ``pyarrow`` and ``awkward``. 
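For instance, with the pandas backend (a minimal sketch; the example data and variable names are
illustrative, not part of the documented API):

.. code-block:: python

    import pandas as pd
    import akimbo.pandas  # importing the backend module registers the ``.ak`` accessor

    s = pd.Series([[1, 2, 3], [], [4, None]])  # ragged data in an ordinary Series
    s.ak.count(axis=1)    # length of each inner list
    s.ak.is_none(axis=1)  # which inner values are missing

The other backends follow the same pattern: import the corresponding ``akimbo`` submodule and the
``.ak`` accessor appears on that library's series/dataframe objects.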
From PyPI diff --git a/docs/demo/muons_dataset1.svg b/docs/muons_dataset1.svg similarity index 100% rename from docs/demo/muons_dataset1.svg rename to docs/muons_dataset1.svg diff --git a/docs/demo/muons_dataset_df.svg b/docs/muons_dataset_df.svg similarity index 100% rename from docs/demo/muons_dataset_df.svg rename to docs/muons_dataset_df.svg diff --git a/src/akimbo/apply_tree.py b/src/akimbo/apply_tree.py index a6531a7..76a4d07 100644 --- a/src/akimbo/apply_tree.py +++ b/src/akimbo/apply_tree.py @@ -80,7 +80,11 @@ def dec( match: function to determine if a part of the data structure matches the type we want to operate on outtype: postprocessing function after transform - inmode: how ``func`` expects its inputs: as awkward arrays (ak), numpy or arrow + inmode: how ``func`` expects its inputs: as + - ak: awkward arrays, + - numpy + - arrow + - other: anything that can be cast to ak arrays, e.g., number literals """ @functools.wraps(func) diff --git a/src/akimbo/cudf.py b/src/akimbo/cudf.py index b5b66dd..13cc3e2 100644 --- a/src/akimbo/cudf.py +++ b/src/akimbo/cudf.py @@ -107,6 +107,12 @@ def f(lay, method=meth, **kwargs): class CudfAwkwardAccessor(Accessor): + """Operations on cuDF dataframes on the GPU. + + Data are kept in GPU memory and use views rather than copies where + possible. + """ + series_type = Series dataframe_type = DataFrame @@ -145,9 +151,17 @@ def str(self): try: cast = dec_cu(libcudf.unary.cast, match=leaf) except AttributeError: + def cast_inner(col, dtype): - return cudf.core.column.ColumnBase(col.data, size=len(col), dtype=np.dtype(dtype), - mask=None, offset=0, children=()) + return cudf.core.column.ColumnBase( + col.data, + size=len(col), + dtype=np.dtype(dtype), + mask=None, + offset=0, + children=(), + ) + cast = dec_cu(cast_inner, match=leaf) @property diff --git a/src/akimbo/dask.py b/src/akimbo/dask.py index 0046add..1ac974c 100644 --- a/src/akimbo/dask.py +++ b/src/akimbo/dask.py @@ -69,7 +69,7 @@ def run(self, *args, **kwargs): ar = [self._to_tt(ar) if hasattr(ar, "ak") else ar for ar in ar] out = op(tt, *ar, **kwargs) meta = PandasAwkwardAccessor._to_output( - ak.typetracer.length_zero_if_typetracer(out) + ak.typetracer.length_one_if_typetracer(out) ) except (ValueError, TypeError): meta = None diff --git a/src/akimbo/datetimes.py b/src/akimbo/datetimes.py index 9f78239..226db43 100644 --- a/src/akimbo/datetimes.py +++ b/src/akimbo/datetimes.py @@ -24,7 +24,7 @@ def __init__(self, accessor) -> None: floor_temporal = dec_t(pc.floor_temporal) reound_temporal = dec_t(pc.round_temporal) strftime = dec_t(pc.strftime) - strptime = dec_t(pc.strptime) + # strptime = dec_t(pc.strptime) # this is in .str instead day = dec_t(pc.day) day_of_week = dec_t(pc.day_of_week) day_of_year = dec_t(pc.day_of_year) diff --git a/src/akimbo/io.py b/src/akimbo/io.py index afaf496..cd32765 100644 --- a/src/akimbo/io.py +++ b/src/akimbo/io.py @@ -6,6 +6,7 @@ def ak_to_series(ds, backend="pandas", extract=True): + """Make backend-specific series from data""" if backend == "pandas": import akimbo.pandas @@ -23,6 +24,9 @@ def ak_to_series(ds, backend="pandas", extract=True): import akimbo.cudf s = akimbo.cudf.CudfAwkwardAccessor._to_output(ds) + elif backend in ["ray", "spark"]: + raise ValueError("Backend only supports dataframes, not series") + else: raise ValueError("Backend must be in {'pandas', 'polars', 'dask'}") if extract and ds.fields: @@ -30,6 +34,8 @@ def ak_to_series(ds, backend="pandas", extract=True): return s +# TODO: read_parquet should use native versions rather than 
convert. This version +# is OK for pandas def read_parquet( url: str, storage_options: dict | None = None, @@ -60,6 +66,8 @@ def read_parquet( return ak_to_series(ds, backend, extract=extract) +# TODO: should be a map over input files, maybe with newline byte blocks +# as in dask def read_json( url: str, storage_options: dict | None = None, @@ -124,6 +132,8 @@ def get_json_schema( return layout_to_jsonschema(arr.layout) +# TODO: should be a map over input files, maybe with newline byte blocks +# as in dask def read_avro( url: str, storage_options: dict | None = None, @@ -205,9 +215,9 @@ def join( merge = _merge counts = np.empty(len(table1), dtype="uint64") - # TODO: the line below over-allocates, can switch to somehing growable + # TODO: the line below over-allocates, can switch to something growable matches = np.empty(len(table2), dtype="uint64") - # TODO: to_numpy(allow_missong) makes this a bit faster, but is not + # TODO: to_numpy(allow_missing) makes this a bit faster, but is not # not GPU general counts, matches, ind = merge(table1[key], table2[key], counts, matches) matches.resize(int(ind), refcheck=False) diff --git a/src/akimbo/mixin.py b/src/akimbo/mixin.py index 6ee3928..eaa346e 100644 --- a/src/akimbo/mixin.py +++ b/src/akimbo/mixin.py @@ -5,6 +5,7 @@ from typing import Callable, Iterable import awkward as ak +import numpy as np import pyarrow.compute as pc from akimbo.apply_tree import dec, match_any, numeric, run_with_transform @@ -82,14 +83,6 @@ class ArithmeticMixin: def _create_op(cls, op): raise AbstractMethodError(cls) - @classmethod - def _create_op(cls, op): - raise AbstractMethodError(cls) - - @classmethod - def _create_op(cls, op): - raise AbstractMethodError(cls) - @classmethod def _add_arithmetic_ops(cls) -> None: setattr(cls, "__add__", cls._create_op(operator.add)) @@ -158,7 +151,7 @@ def accessor(self): @classmethod def is_series(cls, data): - return isinstance(data, cls.series_type) + return isinstance(data, cls.series_type) if cls.series_type else False @classmethod def is_dataframe(cls, data): @@ -210,6 +203,9 @@ def transform( This process walks thought the data's schema tree, and applies the given function only on the matching nodes. + The function input(s) and output depend on inmode and outttpe + arguments. + Parameters ---------- fn: the operation you want to perform. 
Typically unary or binary, and may take @@ -228,10 +224,20 @@ def transform( bits = tuple(where.split(".")) if isinstance(where, str) else where arr = self.array part = arr.__getitem__(bits) - # TODO: apply ``where`` to any arrays in others - # other = [to_ak_layout(ar) for ar in others] + others = ( + _ + if isinstance(_, (str, int, float, np.number)) + else to_ak_layout(_).__getitem__(bits) + for _ in others + ) + callkwargs = { + k: _ + if isinstance(_, (str, int, float, np.number)) + else to_ak_layout(_).__getitem__(bits) + for k, _ in kwargs.items() + } out = run_with_transform( - part, fn, match=match, others=others, inmode=inmode, **kwargs + part, fn, match=match, others=others, inmode=inmode, **callkwargs ) final = ak.with_field(arr, out, where=where) else: @@ -247,7 +253,7 @@ def __getitem__(self, item): def __dir__(self) -> Iterable[str]: attrs = (_ for _ in dir(self.array) if not _.startswith("_")) meths = series_methods if self.is_series(self._obj) else df_methods - return sorted(set(attrs) | set(meths)) + return sorted(set(attrs) | set(meths) | set(self.subaccessors)) def with_behavior(self, behavior, where=()): """Assign a behavior to this array-of-records""" @@ -270,10 +276,33 @@ def with_behavior(self, behavior, where=()): def __array_function__(self, *args, **kwargs): return self.array.__array_function__(*args, **kwargs) - def __array_ufunc__(self, *args, **kwargs): - if args[1] == "__call__": - return self.to_output(args[0](self.array, *args[3:], **kwargs)) - raise NotImplementedError + def __array_ufunc__(self, *args, where=None, out=None, **kwargs): + # includes operator overload like df.ak + 1 + ufunc, call, inputs, *callargs = args + if out is not None or call != "__call__": + raise NotImplementedError + if where: + # called like np.add(df.ak, 1, where="...") + bits = tuple(where.split(".")) if isinstance(where, str) else where + arr = self.array + part = arr.__getitem__(bits) + callargs = ( + _ + if isinstance(_, (str, int, float, np.number)) + else to_ak_layout(_).__getitem__(bits) + for _ in callargs + ) + callkwargs = { + k: _ + if isinstance(_, (str, int, float, np.number)) + else to_ak_layout(_).__getitem__(bits) + for k, _ in kwargs.items() + } + + out = self.to_output(ufunc(part, *callargs, **callkwargs)) + return self.to_output(ak.with_field(arr, out, where=where)) + + return self.to_output(ufunc(self.array, *callargs, **kwargs)) @property def arrow(self) -> ak.Array: @@ -293,7 +322,7 @@ def array(self) -> ak.Array: return ak.from_arrow(self.arrow) @classmethod - def register_accessor(cls, name, klass): + def register_accessor(cls, name: str, klass: type): # TODO: check clobber? 
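        # Example: ``RayAccessor.register_accessor("dt", RayDatetimeAccessor)``
        # (see akimbo/ray.py below) exposes the sub-namespace as ``df.ak.dt``;
        # re-registering an existing name silently overwrites it, which is what
        # the TODO above refers to.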
cls.subaccessors[name] = klass @@ -413,12 +442,14 @@ def op2(*args, extra=None, **kw): args = list(args) + list(extra or []) return op(*args, **kw) - def f(self, *args, **kw): + def f(self, *args, where=None, **kw): # TODO: test here is for literals, but really we want "don't know how to # array that" condition - extra = (_ for _ in args if isinstance(_, (str, int, float))) + extra = [_ for _ in args if isinstance(_, (str, int, float, np.number))] args = ( - to_ak_layout(_) for _ in args if not isinstance(_, (str, int, float)) + to_ak_layout(_) + for _ in args + if not isinstance(_, (str, int, float, np.number)) ) out = self.transform( op2, @@ -427,6 +458,7 @@ def f(self, *args, **kw): inmode="numpy", extra=extra, outtype=ak.contents.NumpyArray, + where=where, **kw, ) if isinstance(self._obj, self.dataframe_type): diff --git a/src/akimbo/polars.py b/src/akimbo/polars.py index db20ab9..cdb2b65 100644 --- a/src/akimbo/polars.py +++ b/src/akimbo/polars.py @@ -1,12 +1,19 @@ +from typing import Callable, Dict + import polars as pl +import pyarrow as pa +from akimbo.apply_tree import match_any from akimbo.mixin import Accessor @pl.api.register_series_namespace("ak") @pl.api.register_dataframe_namespace("ak") class PolarsAwkwardAccessor(Accessor): - """Perform awkward operations on a polars series or dataframe""" + """Perform awkward operations on a polars series or dataframe + + This is for *eager* operations. A Lazy version may eventually be made. + """ series_type = pl.Series dataframe_type = pl.DataFrame @@ -22,3 +29,120 @@ def to_arrow(cls, data): def pack(self): # polars already implements this directly return self._obj.to_struct() + + +@pl.api.register_lazyframe_namespace +class LazyPolarsAwkwardAccessor(Accessor): + dataframe_type = pl.LazyFrame + series_type = None # lazy is never series + + def transform( + self, fn: Callable, *others, where=None, match=match_any, inmode="ak", **kwargs + ): + # TODO determine schema from first-run, with df.collect_schema() + return pl.map_batches( + (self._obj,) + others, + lambda d: d.ak.transform( + fn, match=match, inmode=inmode, **kwargs + ).ak.unpack(), + schema=None, + ) + + +def arrow_to_polars_type(arrow_type: pa.DataType) -> pl.DataType: + type_mapping = { + pa.int8(): pl.Int8, + pa.int16(): pl.Int16, + pa.int32(): pl.Int32, + pa.int64(): pl.Int64, + pa.uint8(): pl.UInt8, + pa.uint16(): pl.UInt16, + pa.uint32(): pl.UInt32, + pa.uint64(): pl.UInt64, + pa.float32(): pl.Float32, + pa.float64(): pl.Float64, + pa.string(): pl.String, + pa.bool_(): pl.Boolean, + } + + if arrow_type in type_mapping: + return type_mapping[arrow_type] + + # parametrised types + if pa.types.is_timestamp(arrow_type): + return pl.Datetime(time_unit=arrow_type.unit, time_zone=arrow_type.tx) + + if pa.types.is_decimal(arrow_type): + return pl.Decimal(precision=arrow_type.precision, scale=arrow_type.scale) + + # Handle list type + if pa.types.is_list(arrow_type): + value_type = arrow_to_polars_type(arrow_type.value_type) + return pl.List(value_type) + + # Handle struct type + if pa.types.is_struct(arrow_type): + fields = {} + for field in arrow_type: + fields[field.name] = arrow_to_polars_type(field.type) + return pl.Struct(fields) + + raise ValueError(f"Unsupported Arrow type: {arrow_type}") + + +def polars_to_arrow_type(polars_type: pl.DataType) -> pa.DataType: + type_mapping = { + pl.Int8: pa.int8(), + pl.Int16: pa.int16(), + pl.Int32: pa.int32(), + pl.Int64: pa.int64(), + pl.UInt8: pa.uint8(), + pl.UInt16: pa.uint16(), + pl.UInt32: pa.uint32(), + pl.UInt64: pa.uint64(), + 
pl.Float32: pa.float32(), + pl.Float64: pa.float64(), + pl.String: pa.string(), + pl.Boolean: pa.bool_(), + pl.Date: pa.date32(), + } + + if polars_type in type_mapping: + return type_mapping[polars_type] + + # parametrised types + if isinstance(polars_type, pl.DataType): + return pa.timestamp(polars_type.unit, polars_type.time_zone) + + if isinstance(polars_type, pl.Decimal): + return pa.decimal128(polars_type.precision, polars_type.scale) + + # Handle list type + if isinstance(polars_type, pl.List): + value_type = polars_to_arrow_type(polars_type.inner) + return pa.list_(value_type) + + # Handle struct type + if isinstance(polars_type, pl.Struct): + fields = [] + for name, dtype in polars_type.fields.items(): + arrow_type = polars_to_arrow_type(dtype) + fields.append(pa.field(name, arrow_type)) + return pa.struct(fields) + + raise ValueError(f"Unsupported Polars type: {polars_type}") + + +def arrow_to_polars_schema(arrow_schema: pa.Schema) -> Dict[str, pl.DataType]: + polars_schema = {} + for field in arrow_schema: + polars_schema[field.name] = arrow_to_polars_type(field.type) + return polars_schema + + +def polars_to_arrow_schema(polars_schema: Dict[str, pl.DataType]) -> pa.Schema: + fields = [] + for name, dtype in polars_schema.items(): + arrow_type = polars_to_arrow_type(dtype) + fields.append(pa.field(name, arrow_type)) + return pa.schema(fields) diff --git a/src/akimbo/ray.py b/src/akimbo/ray.py new file mode 100644 index 0000000..f96bd35 --- /dev/null +++ b/src/akimbo/ray.py @@ -0,0 +1,264 @@ +import functools +from typing import Callable, Iterable + +import awkward as ak +import numpy as np +import pyarrow as pa +import ray +import ray.data as rd + +from akimbo.apply_tree import run_with_transform +from akimbo.datetimes import DatetimeAccessor +from akimbo.datetimes import match as match_dt +from akimbo.mixin import Accessor, match_any, numeric +from akimbo.strings import StringAccessor, match_string, strptime +from akimbo.utils import to_ak_layout + + +class RayStringAccessor(StringAccessor): + def __init__(self, *_): + pass + + def __getattr__(self, attr: str) -> callable: + attr = self.method_name(attr) + return getattr(ak.str, attr) + + @property + def strptime(self): + @functools.wraps(strptime) + def run(*arrs, **kwargs): + arr, *other = arrs + return run_with_transform(arr, strptime, match_string, **kwargs) + + return run + + +class RayDatetimeAccessor: + def __init__(self, *_): + pass + + def __getattr__(self, item): + if item in dir(DatetimeAccessor): + fn = getattr(DatetimeAccessor, item) + if hasattr(fn, "__wrapped__"): + func = fn.__wrapped__ # arrow function + else: + raise AttributeError + else: + raise AttributeError + + @functools.wraps(func) + def run(*arrs, **kwargs): + arr, *other = arrs + return run_with_transform(arr, func, match_dt, **kwargs) + + return run + + def __dir__(self): + return dir(DatetimeAccessor) + + +class RayAccessor(Accessor): + """Operations on ray.data.Dataset dataframes. + + This is a lazy backend, and operates partition-wise. It predicts the schema + of each operation by running with an empty dataframe of the correct type. 
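    A minimal usage sketch (the parquet path here is hypothetical; any nested
    columns that ``ray.data.read_parquet`` can load will do)::

        import ray.data
        import akimbo.ray  # registers the ``.ak`` accessor on ray.data.Dataset

        ds = ray.data.read_parquet("nested.parquet")
        out = ds.ak.is_none(axis=1)   # lazy, applied batch-wise
        pdf = out.ak.to_output()      # collect the result back to pandas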
+ """ + + dataframe_type = rd.Dataset + series_type = None # only has "dataframe like" + subaccessors = Accessor.subaccessors.copy() + + def __init__(self, obj, subaccessor=None, behavior=None): + super().__init__(obj, behavior) + self.subaccessor = subaccessor + + def to_arrow(self, data: rd.Dataset) -> pa.Table: + batches = ray.get(data.to_arrow_refs()) + return pa.concat_tables(batches) + + def to_output(self, data=None): + import pandas as pd + + data = self.to_arrow(data if data is not None else self._obj) + data = data.to_pandas(types_mapper=pd.ArrowDtype) + if list(data.columns) == ["_ak_series_"]: + data = data["_ak_series_"] + return data + + def __getattr__(self, item: str) -> rd.Dataset: + if isinstance(item, str) and item in self.subaccessors: + return RayAccessor(self._obj, subaccessor=item, behavior=self._behavior) + + def select(*inargs, subaccessor=self.subaccessor, where=None, **kwargs): + if subaccessor: + func0 = getattr(self.subaccessors[subaccessor](), item) + elif callable(item): + func0 = item + else: + func0 = None + + def f(batch): + arr = ak.from_arrow(batch) + if any(isinstance(_, str) and _ == "_ak_other_" for _ in inargs): + # binary input + other = arr[[_ for _ in arr.fields if _.startswith("_df2_")]] + # 5 == len("_df2_"); rename to original fields + other.layout._fields[:] = [k[5:] for k in other.fields] + arr = arr[[_ for _ in arr.fields if not _.startswith("_df2_")]] + if other.fields == ["_ak_series_"]: + other = other["_ak_series_"] + if where is not None: + other = other[where] + inargs0 = [other if str(_) == "_ak_other_" else _ for _ in inargs] + else: + inargs0 = inargs + other = None + if where: + arr0 = arr + arr = arr[where] + if arr.fields == ["_ak_series_"]: + arr = arr["_ak_series_"] + + if callable(func0): + func = func0 + args = (arr,) + elif hasattr(arr, item) and callable(getattr(arr, item)): + func = getattr(arr, item) + args = () + elif subaccessor: + func = func0 + args = (arr,) + elif hasattr(ak, item): + func = getattr(ak, item) + args = (arr,) + else: + raise KeyError(item) + + out = func(*args, *inargs0, **kwargs) + if where: + out = ak.with_field(arr0, out, where) + if not out.layout.fields: + out = ak.Array({"_ak_series_": out}) + return ak.to_arrow_table( + out, + extensionarray=False, + list_to32=True, + string_to32=True, + bytestring_to32=True, + ) + + f.__name__ = item.__name__ if callable(item) else item + + inargs = [_._obj if isinstance(_, type(self)) else _ for _ in inargs] + n_others = sum(isinstance(_, self.dataframe_type) for _ in inargs) + if n_others == 1: + other = next(_ for _ in inargs if isinstance(_, self.dataframe_type)) + inargs = [ + "_ak_other_" if isinstance(_, self.dataframe_type) else _ + for _ in inargs + ] + obj = concat_columns_zip_index(self._obj, other) + elif n_others > 1: + raise NotImplementedError + else: + obj = self._obj + arrow_type = obj.schema().base_schema + arr = pa.table([[]] * len(arrow_type), schema=arrow_type) + out1 = f(arr) + out_schema = pa.table(out1).schema + result = obj.map_batches(f, zero_copy_batch=True, batch_format="pyarrow") + result._plan.cache_schema(out_schema) + return result + + return select + + def __array_ufunc__(self, *args, where=None, out=None, **kwargs): + # includes operator overload like df.ak + 1 + ufunc, call, inputs, *callargs = args + if out is not None or call != "__call__": + raise NotImplementedError + + return self.__getattr__(ufunc)(*callargs, where=where, **kwargs) + + def __getitem__(self, item) -> rd.dataset: + def f(batch): + arr = 
ak.from_arrow(batch) + arr2 = arr.__getitem__(item) + if not arr2.fields: + arr2 = ak.Array({"_ak_series_": arr2}) + return ak.to_arrow_table( + arr2, + extensionarray=False, + list_to32=True, + string_to32=True, + bytestring_to32=True, + ) + + arrow_type = self._obj.schema().base_schema + if not isinstance(arrow_type, pa.Schema): + # TODO: fix, for data via from_pandas or from_numpy + raise ValueError("Use arrow types") + arr = pa.table([[]] * len(arrow_type), schema=arrow_type) + out1 = f(arr) + out_schema = pa.table(out1).schema + result = self._obj.map_batches(f, zero_copy_batch=True, batch_format="pyarrow") + # this is what .schema(fetch_if_missing=True) does, but we already know + # the value without compute + result._plan.cache_schema(out_schema) + return result + + def transform( + self, + fn: callable, + *others, + where=None, + match=match_any, + inmode="array", + **kwargs, + ): + def f(arr, *others, **kwargs): + return run_with_transform( + arr, fn, match=match, others=others, inmode=inmode, **kwargs + ) + + return self.__getattr__(f)(*others, **kwargs) + + def apply(self, fn: Callable, *others, where=None, **kwargs): + return self.__getattr__(fn)(*others, **kwargs) + + @classmethod + def _create_op(cls, op): + def run(self, *args, **kwargs): + args = [ + to_ak_layout(_) if isinstance(_, (str, int, float, np.number)) else _ + for _ in args + ] + return self.transform(op, *args, match=numeric) + + return run + + def __dir__(self) -> Iterable[str]: + if self.subaccessor is not None: + return dir(self.subaccessors[self.subaccessor](self)) + return super().__dir__() + + +RayAccessor.register_accessor("dt", RayDatetimeAccessor) +RayAccessor.register_accessor("str", RayStringAccessor) + + +def concat_columns_zip_index(df1: rd.Dataset, df2: rd.Dataset) -> rd.Dataset: + """Add two DataFrames' columns into a single DF.""" + if df1.num_blocks != df2.num_blocks: + # warn that this causes shuffle + pass + return df1.zip(df2.rename_columns({k: f"_df2_{k}" for k in df2.columns()})) + + +@property # type:ignore +def ak_property(self): + return RayAccessor(self) + + +rd.Dataset.ak = ak_property # Ray has no Series diff --git a/src/akimbo/spark.py b/src/akimbo/spark.py new file mode 100644 index 0000000..99f60cd --- /dev/null +++ b/src/akimbo/spark.py @@ -0,0 +1,266 @@ +import functools +from typing import Callable, Iterable + +import awkward as ak +import numpy as np +import pyarrow as pa +import pyspark +from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema + +from akimbo.apply_tree import run_with_transform +from akimbo.datetimes import DatetimeAccessor +from akimbo.datetimes import match as match_dt +from akimbo.mixin import Accessor, match_any, numeric +from akimbo.pandas import pd +from akimbo.strings import StringAccessor, match_string, strptime +from akimbo.utils import to_ak_layout + +sdf = pyspark.sql.DataFrame + + +class SparkStringAccessor(StringAccessor): + def __init__(self, *_): + pass + + def __getattr__(self, attr: str) -> callable: + attr = self.method_name(attr) + return getattr(ak.str, attr) + + @property + def strptime(self): + @functools.wraps(strptime) + def run(*arrs, **kwargs): + arr, *other = arrs + return run_with_transform(arr, strptime, match_string, **kwargs) + + return run + + +class SparkDatetimeAccessor: + def __init__(self, *_): + pass + + def __getattr__(self, item): + if item in dir(DatetimeAccessor): + fn = getattr(DatetimeAccessor, item) + if hasattr(fn, "__wrapped__"): + func = fn.__wrapped__ # arrow function + else: + raise AttributeError 
+ else: + raise AttributeError + + @functools.wraps(func) + def run(*arrs, **kwargs): + arr, *other = arrs + return run_with_transform(arr, func, match_dt, **kwargs) + + return run + + def __dir__(self): + return dir(DatetimeAccessor) + + +class SparkAccessor(Accessor): + """Operations on pyspark dataframes. + + This is a lazy backend, and operates partition-wise. It predicts the schema + of each operation by running with an empty dataframe of the correct type. + """ + + subaccessors = Accessor.subaccessors.copy() + dataframe_type = sdf + + def __init__(self, obj, subaccessor=None, behavior=None): + super().__init__(obj, behavior) + self.subaccessor = subaccessor + + def to_arrow(self, data) -> pa.Table: + # collects data locally + batches = data._collect_as_arrow() + return pa.Table.from_batches(batches) + + def to_output(self, data=None) -> pd.DataFrame | pd.Series: + # data is always arrow format internally + data = self.to_arrow(data if data is not None else self._obj).to_pandas( + types_mapper=pd.ArrowDtype + ) + if list(data.columns) == ["_ak_series_"]: + data = data["_ak_series_"] + return data + + def __getattr__(self, item: str) -> sdf: + if isinstance(item, str) and item in self.subaccessors: + return SparkAccessor(self._obj, subaccessor=item, behavior=self._behavior) + + def select(*inargs, subaccessor=self.subaccessor, where=None, **kwargs): + if subaccessor: + func0 = getattr(self.subaccessors[subaccessor](), item) + elif callable(item): + func0 = item + else: + func0 = None + + def f(batches): + for batch in batches: + arr = ak.from_arrow(batch) + if any(isinstance(_, str) and _ == "_ak_other_" for _ in inargs): + # binary input + arr, other = arr["_1"], arr["_df2"] + if other.fields == ["_ak_series_"]: + other = other["_ak_series_"] + if where is not None: + other = other[where] + inargs0 = [ + other if str(_) == "_ak_other_" else _ for _ in inargs + ] + else: + inargs0 = inargs + other = None + if where: + arr0 = arr + arr = arr[where] + if arr.fields == ["_ak_series_"]: + arr = arr["_ak_series_"] + + if callable(func0): + func = func0 + args = (arr,) + elif hasattr(arr, item) and callable(getattr(arr, item)): + func = getattr(arr, item) + args = () + elif subaccessor: + func = func0 + args = (arr,) + elif hasattr(ak, item): + func = getattr(ak, item) + args = (arr,) + else: + raise KeyError(item) + + out = func(*args, *inargs0, **kwargs) + if where: + out = ak.with_field(arr0, out, where) + if not out.layout.fields: + out = ak.Array({"_ak_series_": out}) + arrout = ak.to_arrow( + out, + extensionarray=False, + list_to32=True, + string_to32=True, + bytestring_to32=True, + ) + yield pa.RecordBatch.from_struct_array(arrout) + + f.__name__ = item.__name__ if callable(item) else item + + inargs = [_._obj if isinstance(_, type(self)) else _ for _ in inargs] + n_others = sum(isinstance(_, self.dataframe_type) for _ in inargs) + if n_others == 1: + other = next(_ for _ in inargs if isinstance(_, self.dataframe_type)) + inargs = [ + "_ak_other_" if isinstance(_, self.dataframe_type) else _ + for _ in inargs + ] + obj = concat_columns_zip_index(self._obj, other) + elif n_others > 1: + raise NotImplementedError + else: + obj = self._obj + arrow_type = to_arrow_schema(obj.schema) + arr = pa.table([[]] * len(arrow_type), schema=arrow_type) + out1 = next(f([arr])) + out_schema = pa.table(out1).schema + return obj.mapInArrow(f, schema=from_arrow_schema(out_schema)) + + return select + + def __array_ufunc__(self, *args, where=None, out=None, **kwargs): + # includes operator overload like 
df.ak + 1 + ufunc, call, inputs, *callargs = args + if out is not None or call != "__call__": + raise NotImplementedError + + return self.__getattr__(ufunc)(*callargs, where=where, **kwargs) + + def __getitem__(self, item) -> sdf: + def f(batches): + for batch in batches: + arr = ak.from_arrow(batch) + arr2 = arr.__getitem__(item) + if not arr2.fields: + arr2 = ak.Array({"_ak_series_": arr2}) + out = ak.to_arrow( + arr2, + extensionarray=False, + list_to32=True, + string_to32=True, + bytestring_to32=True, + ) + yield pa.RecordBatch.from_struct_array(out) + + arrow_type = to_arrow_schema(self._obj.schema) + arr = pa.table([[]] * len(arrow_type), schema=arrow_type) + out1 = next(f([arr])) + out_schema = pa.table(out1).schema + return self._obj.mapInArrow(f, schema=from_arrow_schema(out_schema)) + + def transform( + self, + fn: callable, + *others, + where=None, + match=match_any, + inmode="array", + **kwargs, + ): + def f(arr, *others, **kwargs): + return run_with_transform( + arr, fn, match=match, others=others, inmode=inmode, **kwargs + ) + + return self.__getattr__(f)(*others, **kwargs) + + def apply(self, fn: Callable, *others, where=None, **kwargs): + return self.__getattr__(fn)(*others, **kwargs) + + @classmethod + def _create_op(cls, op): + def run(self, *args, **kwargs): + args = [ + to_ak_layout(_) if isinstance(_, (str, int, float, np.number)) else _ + for _ in args + ] + return self.transform(op, *args, match=numeric) + + return run + + def __dir__(self) -> Iterable[str]: + if self.subaccessor is not None: + return dir(self.subaccessors[self.subaccessor](self)) + return super().__dir__() + + +SparkAccessor.register_accessor("dt", SparkDatetimeAccessor) +SparkAccessor.register_accessor("str", SparkStringAccessor) + + +def concat_columns_zip_index(df1: sdf, df2: sdf) -> sdf: + """Add two DataFrames' columns into a single DF. + + The is SQL-tricky, but at least it requires no python map/iteration! 
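    Both inputs get a row index from ``zipWithIndex``; the frames are joined on
    that index column (``"_2"``), re-sorted to preserve row order, and the index
    is dropped, so each original row survives as a struct column (``"_1"`` for
    ``df1`` and ``"_df2"`` for ``df2``), matching what the batch function above
    unpacks.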
+ """ + if df1.rdd.getNumPartitions() != df2.rdd.getNumPartitions(): + # warn that this causes shuffle + pass + df1_ind = df1.rdd.zipWithIndex().toDF() + df2_ind = df2.rdd.zipWithIndex().toDF().withColumnRenamed("_1", "_df2") + return df1_ind.join(df2_ind, "_2", "left").sort("_2").drop("_2") + + +@property # type:ignore +def ak_property(self): + return SparkAccessor(self) + + +pyspark.sql.DataFrame.ak = ak_property # spark has no Series
diff --git a/src/akimbo/strings.py b/src/akimbo/strings.py index 1460154..a614049 100644 --- a/src/akimbo/strings.py +++ b/src/akimbo/strings.py @@ -1,7 +1,6 @@ from __future__ import annotations import functools -from collections.abc import Callable import awkward as ak import pyarrow.compute as pc @@ -48,12 +47,15 @@ def _decode(layout): if not aname.startswith(("_", "akstr_")) and not aname[0].isupper() ] -# make sensible defaults for strptime -strptime = functools.wraps(pc.strptime)( - lambda *args, format="%FT%T", unit="s", error_is_null=True, **kw: pc.strptime( - *args, format=format, unit=unit, error_is_null=error_is_null + +@functools.wraps(pc.strptime) +def strptime(*args, format="%FT%T", unit="us", error_is_null=True, **kw): + """strptime with typical defaults set to reverse strftime""" + out = pc.strptime( + *args, format=format, unit=unit, error_is_null=error_is_null, **kw ) -) + return out class StringAccessor: @@ -91,7 +93,7 @@ def decode(self, encoding: str = "utf-8"): def method_name(attr: str) -> str: return _SA_METHODMAPPING.get(attr, attr) - def __getattr__(self, attr: str) -> Callable: + def __getattr__(self, attr: str) -> callable: attr = self.method_name(attr) fn = getattr(ak.str, attr)
diff --git a/tests/test_dask.py b/tests/test_dask.py index 1c92bbe..b6ca8c8 100644 --- a/tests/test_dask.py +++ b/tests/test_dask.py @@ -26,7 +26,7 @@ def test_accessor(): s = pd.Series(data) df = pd.DataFrame({"s": s}) ddf = dd.from_pandas(df, 2) - out = ddf.s.ak.count() + out = ddf.s.ak.count() # causes dask warning, as each partition reduces to scalar assert out.compute().tolist() == [3, 3] out = ddf.s.ak.count(axis=1).compute()
diff --git a/tests/test_dt.py b/tests/test_dt.py index 9ae16ea..4a78f4b 100644 --- a/tests/test_dt.py +++ b/tests/test_dt.py @@ -78,4 +78,6 @@ def test_text_conversion(): s = pd.Series([["2024-08-01T01:00:00", None, "2024-08-01T01:01:00"]]) s2 = s.ak.str.strptime() s3 = s2.ak.dt.strftime("%FT%T") - assert s3.tolist() == [["2024-08-01T01:00:00", None, "2024-08-01T01:01:00"]] + # remove trailing zeros - depends on system defaults + out = [None if _ is None else _.split(".")[0] for _ in s3.tolist()[0]] + assert out == ["2024-08-01T01:00:00", None, "2024-08-01T01:01:00"]
diff --git a/tests/test_ray.py b/tests/test_ray.py new file mode 100644 index 0000000..c920507 --- /dev/null +++ b/tests/test_ray.py @@ -0,0 +1,152 @@ +import sys + +import awkward as ak +import numpy as np +import pytest + +WIN = sys.platform.startswith("win") +pd = pytest.importorskip("pandas") +ray = pytest.importorskip("ray") + +pytest.importorskip("akimbo.pandas") +pytest.importorskip("akimbo.ray") + +x = pd.Series([[[1, 2, 3], [], [3, 4, None]], [None]] * 100).ak.to_output() +y = pd.Series([["hey", None], ["hi", "ho"]] * 100).ak.to_output() + + +@pytest.fixture(scope="module") +def rayc(): + context = ray.init() + yield context + context.disconnect() + + +@pytest.fixture() +def df(rayc, tmpdir): + import ray.data + + pd.DataFrame({"x": x, "y": y}).to_parquet(f"{tmpdir}/a.parquet") + return 
ray.data.read_parquet(f"{tmpdir}/a.parquet", override_num_blocks=2) + + +def test_unary(df): + out = df.ak.is_none() + result = out.ak.to_output() + expected = x.ak.is_none() + assert result.x.tolist() == expected.tolist() + + out = df.ak.is_none(axis=1) + result = out.ak.to_output() + expected = x.ak.is_none(axis=1) + assert result.x.tolist() == expected.tolist() + + out = df.ak.str.upper() + result = out.ak.to_output() + expected = y.ak.str.upper() + assert result.y.tolist() == expected.tolist() + + +@pytest.mark.skipif(WIN, reason="may not have locale on windows") +def test_dt(rayc): + data = pd.DataFrame( + { + "_ak_series_": pd.Series( + pd.date_range(start="2024-01-01", end="2024-01-02", freq="h") + ) + } + ) + df = ray.data.from_arrow(data.ak.arrow) + out = df.ak.dt.strftime() + result = out.ak.to_output() + + assert result[0] == "2024-01-01T00:00:00.000000000" # defaults to ns + + out = df.ak.dt.hour() + result = out.ak.to_output() + assert set(result) == set(range(24)) + + +def test_select(df): + out = df.ak["x"] + result = out.ak.to_output() + assert isinstance(result, pd.Series) + assert result.tolist() == x.tolist() + out = df.ak[:, ::2] + result = out.ak.to_output() + expected = x.ak[:, ::2] + assert result.x.tolist() == expected.tolist() + + +def test_binary(df): + out = df.ak["x"].ak.isclose(1) + result = out.ak.to_output() + expected = x.ak.isclose(1) + assert result.tolist() == expected.tolist() + + out = df.ak["x"].ak.isclose(df.ak["x"]) + result = out.ak.to_output() + expected = x.ak.isclose(x) + assert result.tolist() == expected.tolist() + + +def test_ufunc(df): + out = np.negative(df.ak["x"].ak) + result = out.ak.to_output() + expected = np.negative(x.ak) + assert result.tolist() == expected.tolist() + + out = np.add(df.ak["x"].ak, df.ak["x"]) + result = out.ak.to_output() + expected = x.ak * 2 + assert result.tolist() == expected.tolist() + + +def test_ufunc_where(df): + out = np.add(df.ak, df, where="x") + result = out.ak.to_output() + expected = x.ak * 2 + assert result.x.tolist() == expected.tolist() + + +def test_overload(rayc): + x = pd.Series([1, 2, 3]) + df = ray.data.from_arrow(pd.DataFrame(x, columns=["_ak_series_"]).ak.arrow) + + out = df.ak + 1 # scalar + result = out.ak.to_output() + assert result.tolist() == [2, 3, 4] + + out = df.ak + [1] # array-like, broadcast + result = out.ak.to_output() + assert result.tolist() == [2, 3, 4] + + out = df.ak == df.ak # matching layout + result = out.ak.to_output() + assert result.tolist() == [True, True, True] + + +def test_dir(df): + assert "flatten" in dir(df.ak) + assert "upper" in dir(df.ak.str) + + +def test_apply_numba(df): + numba = pytest.importorskip("numba") + + @numba.njit() + def f(data: ak.Array, builder: ak.ArrayBuilder) -> None: + for i, item in enumerate(data.x): + if item[0] is None: + builder.append(None) + else: + builder.append(item[0][2] + item[2][0]) # always 6 + + def f2(data): + builder = ak.ArrayBuilder() + f(data, builder) + return builder.snapshot() + + out = df.ak.apply(f2, where="x") + result = out.ak.to_output() + assert result.ak.tolist() == [6, None] * 100 diff --git a/tests/test_spark.py b/tests/test_spark.py index 2a9a7f0..ddf3dec 100644 --- a/tests/test_spark.py +++ b/tests/test_spark.py @@ -1,19 +1,169 @@ +import os +import sys + +import awkward as ak +import numpy as np import pytest +WIN = sys.platform.startswith("win") pd = pytest.importorskip("pandas") pyspark = pytest.importorskip("pyspark") -import akimbo.spark + +pytest.importorskip("akimbo.pandas") 
+pytest.importorskip("akimbo.spark") + +x = pd.Series([[[1, 2, 3], [], [3, 4, None]], [None]] * 100).ak.to_output() +y = pd.Series([["hey", None], ["hi", "ho"]] * 100).ak.to_output() @pytest.fixture(scope="module") def spark(): from pyspark.sql import SparkSession - return SparkSession.builder.appName("test").getOrCreate() + os.environ["PYSPARK_PYTHON"] = sys.executable + os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable + + return ( + SparkSession.builder + # .config("spark.sql.execution.arrow.enabled", "true") # this was spark<3.0.0 + .config("spark.sql.execution.pythonUDF.arrow.enabled", "true") + .config("spark.sql.execution.arrow.pyspark.enabled", "true") + .config("spark.sql.execution.arrow.pyspark.fallback.enabled", "false") + .appName("test") + .getOrCreate() + ) + + +@pytest.fixture() +def df(spark, tmpdir): + pd.DataFrame({"x": x, "y": y}).to_parquet(f"{tmpdir}/a.parquet") + return spark.read.parquet(f"{tmpdir}/a.parquet") + + +def test_unary(df): + out = df.ak.is_none() + result = out.ak.to_output() + expected = x.ak.is_none() + assert result.x.tolist() == expected.tolist() + + out = df.ak.is_none(axis=1) + result = out.ak.to_output() + expected = x.ak.is_none(axis=1) + assert result.x.tolist() == expected.tolist() + + out = df.ak.str.upper() + result = out.ak.to_output() + expected = y.ak.str.upper() + assert result.y.tolist() == expected.tolist() + + +@pytest.mark.skipif(WIN, reason="may not have locale on windows") +def test_dt(spark): + data = ( + pd.DataFrame( + { + "_ak_series_": pd.Series( + pd.date_range(start="2024-01-01", end="2024-01-02", freq="h") + ) + } + ) + .ak.to_output() + .ak.unpack() + ) + df = spark.createDataFrame(data) + out = df.ak.dt.strftime() + result = out.ak.to_output() + + assert result[0] == "2024-01-01T00:00:00.000000" + + out = df.ak.dt.hour() + result = out.ak.to_output() + assert set(result) == set(range(24)) + + +def test_select(df): + out = df.ak["x"] + result = out.ak.to_output() + assert isinstance(result, pd.Series) + assert result.tolist() == x.tolist() + out = df.ak[:, ::2] + result = out.ak.to_output() + expected = x.ak[:, ::2] + assert result.x.tolist() == expected.tolist() + +def test_binary(df): + out = df.ak["x"].ak.isclose(1) + result = out.ak.to_output() + expected = x.ak.isclose(1) + assert result.tolist() == expected.tolist() -def test1(spark): + out = df.ak["x"].ak.isclose(df.ak["x"]) + result = out.ak.to_output() + expected = x.ak.isclose(x) + assert result.tolist() == expected.tolist() + + +def test_ufunc(df): + out = np.negative(df.ak["x"].ak) + result = out.ak.to_output() + expected = np.negative(x.ak) + assert result.tolist() == expected.tolist() + + out = np.add(df.ak["x"].ak, df.ak["x"]) + result = out.ak.to_output() + expected = x.ak * 2 + assert result.tolist() == expected.tolist() + + +def test_ufunc_where(df): + out = np.add(df.ak, df, where="x") + result = out.ak.to_output() + expected = x.ak * 2 + assert result.x.tolist() == expected.tolist() + + +def test_overload(spark): x = pd.Series([1, 2, 3]) - df = spark.createDataFrame(pd.DataFrame(x, columns=["x"])) - out = df.ak.is_none.collect() - assert out.tolist() == [False, False, False] + df = spark.createDataFrame(pd.DataFrame(x, columns=["_ak_series_"])) + + out = df.ak + 1 # scalar + result = out.ak.to_output() + assert result.tolist() == [2, 3, 4] + + out = df.ak + [1] # array-like, broadcast + result = out.ak.to_output() + assert result.tolist() == [2, 3, 4] + + out = df.ak == df.ak # matching layout + result = out.ak.to_output() + assert result.tolist() == 
[True, True, True] + + +def test_dir(df): + assert "flatten" in dir(df.ak) + assert "upper" in dir(df.ak.str) + + +def test_apply_numba(df): + numba = pytest.importorskip("numba") + + def f(data: ak.Array, builder: ak.ArrayBuilder) -> None: + for i, item in enumerate(data.x): + if item[0] is None: + builder.append(None) + else: + builder.append(item[0][2] + item[2][0]) # always 6 + + def f2(data): + if len(data): + builder = ak.ArrayBuilder() + numba.njit(f)(data, builder) + return builder.snapshot() + else: + # default output for zero-length schema guesser + return ak.Array([None, 6]) + + out = df.ak.apply(f2, where="x") + result = out.ak.to_output() + assert result.ak.tolist() == [6, None] * 100
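Both new lazy backends predict output schemas the same way, as seen in ``RayAccessor.__getitem__`` and ``SparkAccessor.__getitem__`` above: build a zero-row arrow table with the input schema, run the batch function on it once, and read the schema of the result. A condensed, self-contained sketch of that idea (``predict_schema`` and ``flatten_x`` are illustrative names, not part of the akimbo API):

```python
import awkward as ak
import pyarrow as pa


def predict_schema(input_schema: pa.Schema, batch_fn) -> pa.Schema:
    # zero-row table with the right column types, as the accessors above build
    empty = pa.table([[]] * len(input_schema), schema=input_schema)
    # running the arrow-in/arrow-out kernel on no rows is cheap and yields
    # the output schema without touching real data
    return batch_fn(empty).schema


def flatten_x(batch: pa.Table) -> pa.Table:
    # example batch function in the same arrow-in/arrow-out style that the
    # backends map over each batch/partition
    arr = ak.from_arrow(batch)
    out = ak.Array({"x": ak.flatten(arr["x"], axis=1)})
    return ak.to_arrow_table(out, extensionarray=False)


schema = pa.schema({"x": pa.list_(pa.int64())})
print(predict_schema(schema, flatten_x))  # schema of the would-be output, e.g. x: int64
```

This is why the lazy accessors can cache or declare the result schema (``result._plan.cache_schema(out_schema)`` for ray, ``mapInArrow(f, schema=...)`` for spark) without triggering any computation on the real data.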