Skip to content

Commit

Permalink
use globals instead of threading.local (fixes #281) (#282)
Browse files Browse the repository at this point in the history
* use globals instead of threading.local (fixes #281)

* fix broken tesT

* satisfy mypy
  • Loading branch information
dionhaefner authored Oct 28, 2022
1 parent 9c551bb commit 964a72b
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 21 deletions.
15 changes: 8 additions & 7 deletions terracotta/drivers/geotiff_raster_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@

logger = logging.getLogger(__name__)

context = threading.local()
context.executor = None
_executor = None


def create_executor() -> Executor:
Expand All @@ -50,16 +49,18 @@ def create_executor() -> Executor:


def submit_to_executor(task: Callable[..., Any]) -> Future:
if context.executor is None:
context.executor = create_executor()
global _executor

if _executor is None:
_executor = create_executor()

try:
future = context.executor.submit(task)
future = _executor.submit(task)
except BrokenProcessPool:
# re-create executor and try again
logger.warn('Re-creating broken process pool')
context.executor = create_executor()
future = context.executor.submit(task)
_executor = create_executor()
future = _executor.submit(task)

return future

Expand Down
2 changes: 1 addition & 1 deletion terracotta/drivers/relational_meta_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def version_tuple(version_string: str) -> Sequence[str]:
)
self.db_version_verified = True

@property # type: ignore
@property
@requires_connection
@convert_exceptions(_ERROR_ON_CONNECT)
def db_version(self) -> str:
Expand Down
27 changes: 14 additions & 13 deletions tests/drivers/test_raster_drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,31 +504,32 @@ def test_nodata_consistency(driver_path, provider, big_raster_file_mask, big_ras


@pytest.mark.parametrize('provider', DRIVERS)
def test_broken_process_pool(driver_path, provider, raster_file):
def test_broken_process_pool(monkeypatch, driver_path, provider, raster_file):
import concurrent.futures
import terracotta.drivers.geotiff_raster_store
from terracotta import drivers
from terracotta.drivers.geotiff_raster_store import context

class BrokenPool:
def submit(self, *args, **kwargs):
raise concurrent.futures.process.BrokenProcessPool('monkeypatched')

context.executor = BrokenPool()
with monkeypatch.context() as m:
m.setattr(terracotta.drivers.geotiff_raster_store, '_executor', BrokenPool())

db = drivers.get_driver(driver_path, provider=provider)
keys = ('some', 'keynames')
db = drivers.get_driver(driver_path, provider=provider)
keys = ('some', 'keynames')

db.create(keys)
db.insert(['some', 'value'], str(raster_file))
db.insert(['some', 'other_value'], str(raster_file))
db.create(keys)
db.insert(['some', 'value'], str(raster_file))
db.insert(['some', 'other_value'], str(raster_file))

data1 = db.get_raster_tile(['some', 'value'], tile_size=(256, 256))
assert data1.shape == (256, 256)
data1 = db.get_raster_tile(['some', 'value'], tile_size=(256, 256))
assert data1.shape == (256, 256)

data2 = db.get_raster_tile(['some', 'other_value'], tile_size=(256, 256))
assert data2.shape == (256, 256)
data2 = db.get_raster_tile(['some', 'other_value'], tile_size=(256, 256))
assert data2.shape == (256, 256)

np.testing.assert_array_equal(data1, data2)
np.testing.assert_array_equal(data1, data2)


def test_no_multiprocessing():
Expand Down

0 comments on commit 964a72b

Please sign in to comment.