Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions src/crawlee/_utils/retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from __future__ import annotations

from asyncio import sleep as _retry_sleep # Using alias for testing purposes
from datetime import timedelta
from functools import wraps
from typing import TYPE_CHECKING, ParamSpec, TypeVar

if TYPE_CHECKING:
from collections.abc import Awaitable, Callable

P = ParamSpec('P')
T = TypeVar('T')


def retry_on_error(
*exception_types: type[Exception],
max_attempts: int = 3,
base_delay: timedelta = timedelta(milliseconds=500),
) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]:
"""Retry an async function with exponential backoff on specified exceptions.

Args:
*exception_types: Exception types to catch and retry on.
max_attempts: Maximum number of attempts including the first one.
base_delay: Base delay between retries; doubles on each subsequent attempt.
"""
if max_attempts < 1:
raise ValueError('max_attempts must be at least 1')

if base_delay < timedelta(0):
raise ValueError('base_delay must be a non-negative timedelta')

if not exception_types:
raise ValueError('At least one exception type must be specified')

def decorator(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
@wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
base_delay_seconds = base_delay.total_seconds()
for attempt in range(max_attempts):
try:
return await func(*args, **kwargs)
except exception_types: # noqa: PERF203
if attempt >= max_attempts - 1:
raise
await _retry_sleep(base_delay_seconds * (2**attempt))
raise RuntimeError('Unreachable')

return wrapper

return decorator
7 changes: 7 additions & 0 deletions src/crawlee/storage_clients/_redis/_dataset_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from logging import getLogger
from typing import TYPE_CHECKING, Any, cast

from redis.exceptions import RedisError
from typing_extensions import NotRequired, override

from crawlee._utils.retry import retry_on_error
from crawlee.storage_clients._base import DatasetClient
from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata

Expand Down Expand Up @@ -102,14 +104,17 @@ async def open(
instance_kwargs={},
)

@retry_on_error(RedisError)
@override
async def get_metadata(self) -> DatasetMetadata:
return await self._get_metadata(DatasetMetadata)

@retry_on_error(RedisError)
@override
async def drop(self) -> None:
await self._drop(extra_keys=[self._items_key])

@retry_on_error(RedisError)
@override
async def purge(self) -> None:
await self._purge(
Expand All @@ -119,6 +124,7 @@ async def purge(self) -> None:
),
)

@retry_on_error(RedisError)
@override
async def push_data(self, data: list[dict[str, Any]] | dict[str, Any]) -> None:
if isinstance(data, dict):
Expand All @@ -133,6 +139,7 @@ async def push_data(self, data: list[dict[str, Any]] | dict[str, Any]) -> None:
),
)

@retry_on_error(RedisError)
@override
async def get_data(
self,
Expand Down
10 changes: 9 additions & 1 deletion src/crawlee/storage_clients/_redis/_key_value_store_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
from logging import getLogger
from typing import TYPE_CHECKING, Any

from redis.exceptions import RedisError
from typing_extensions import override

from crawlee._utils.file import infer_mime_type
from crawlee._utils.retry import retry_on_error
from crawlee.storage_clients._base import KeyValueStoreClient
from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecord, KeyValueStoreRecordMetadata

Expand Down Expand Up @@ -100,21 +102,25 @@ async def open(
instance_kwargs={},
)

@retry_on_error(RedisError)
@override
async def get_metadata(self) -> KeyValueStoreMetadata:
return await self._get_metadata(KeyValueStoreMetadata)

@retry_on_error(RedisError)
@override
async def drop(self) -> None:
await self._drop(extra_keys=[self._items_key, self._metadata_items_key])

@retry_on_error(RedisError)
@override
async def purge(self) -> None:
await self._purge(
extra_keys=[self._items_key, self._metadata_items_key],
metadata_kwargs=MetadataUpdateParams(update_accessed_at=True, update_modified_at=True),
)

@retry_on_error(RedisError)
@override
async def set_value(self, *, key: str, value: Any, content_type: str | None = None) -> None:
# Special handling for None values
Expand All @@ -141,7 +147,6 @@ async def set_value(self, *, key: str, value: Any, content_type: str | None = No
content_type=content_type,
size=size,
)

async with self._get_pipeline() as pipe:
# redis-py typing issue
await await_redis_response(pipe.hset(self._items_key, key, value_bytes)) # ty: ignore[invalid-argument-type]
Expand All @@ -155,6 +160,7 @@ async def set_value(self, *, key: str, value: Any, content_type: str | None = No
)
await self._update_metadata(pipe, **MetadataUpdateParams(update_accessed_at=True, update_modified_at=True))

@retry_on_error(RedisError)
@override
async def get_value(self, *, key: str) -> KeyValueStoreRecord | None:
serialized_metadata_item = await await_redis_response(self._redis.hget(self._metadata_items_key, key))
Expand Down Expand Up @@ -200,6 +206,7 @@ async def get_value(self, *, key: str) -> KeyValueStoreRecord | None:

return KeyValueStoreRecord(value=value, **metadata_item.model_dump())

@retry_on_error(RedisError)
@override
async def delete_value(self, *, key: str) -> None:
async with self._get_pipeline() as pipe:
Expand Down Expand Up @@ -251,6 +258,7 @@ async def iterate_keys(
async def get_public_url(self, *, key: str) -> str:
raise NotImplementedError('Public URLs are not supported for memory key-value stores.')

@retry_on_error(RedisError)
@override
async def record_exists(self, *, key: str) -> bool:
async with self._get_pipeline(with_execute=False) as pipe:
Expand Down
10 changes: 10 additions & 0 deletions src/crawlee/storage_clients/_redis/_request_queue_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
from logging import getLogger
from typing import TYPE_CHECKING, Any, Literal

from redis.exceptions import RedisError
from typing_extensions import NotRequired, override

from crawlee import Request
from crawlee._utils.crypto import crypto_random_object_id
from crawlee._utils.retry import retry_on_error
from crawlee.storage_clients._base import RequestQueueClient
from crawlee.storage_clients.models import AddRequestsResponse, ProcessedRequest, RequestQueueMetadata

Expand Down Expand Up @@ -207,10 +209,12 @@ async def open(
instance_kwargs={'dedup_strategy': dedup_strategy, 'bloom_error_rate': bloom_error_rate},
)

@retry_on_error(RedisError)
@override
async def get_metadata(self) -> RequestQueueMetadata:
return await self._get_metadata(RequestQueueMetadata)

@retry_on_error(RedisError)
@override
async def drop(self) -> None:
if self._dedup_strategy == 'bloom':
Expand All @@ -222,6 +226,7 @@ async def drop(self) -> None:
extra_keys.extend([self._queue_key, self._data_key, self._in_progress_key])
await self._drop(extra_keys=extra_keys)

@retry_on_error(RedisError)
@override
async def purge(self) -> None:
if self._dedup_strategy == 'bloom':
Expand Down Expand Up @@ -349,6 +354,7 @@ async def add_batch_of_requests(
unprocessed_requests=[],
)

@retry_on_error(RedisError)
@override
async def fetch_next_request(self) -> Request | None:
if self._pending_fetch_cache:
Expand Down Expand Up @@ -377,6 +383,7 @@ async def fetch_next_request(self) -> Request | None:

return requests[0]

@retry_on_error(RedisError)
@override
async def get_request(self, unique_key: str) -> Request | None:
request_data = await await_redis_response(self._redis.hget(self._data_key, unique_key))
Expand All @@ -386,6 +393,7 @@ async def get_request(self, unique_key: str) -> Request | None:

return None

@retry_on_error(RedisError)
@override
async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None:
# Check if the request is in progress.
Expand Down Expand Up @@ -424,6 +432,7 @@ async def mark_request_as_handled(self, request: Request) -> ProcessedRequest |
was_already_handled=True,
)

@retry_on_error(RedisError)
@override
async def reclaim_request(
self,
Expand Down Expand Up @@ -469,6 +478,7 @@ async def reclaim_request(
was_already_handled=False,
)

@retry_on_error(RedisError)
@override
async def is_empty(self) -> bool:
"""Check if the queue is empty.
Expand Down
1 change: 1 addition & 0 deletions src/crawlee/storage_clients/_sql/_client_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ async def get_session(self, *, with_simple_commit: bool = False) -> AsyncIterato
except SQLAlchemyError as e:
logger.warning(f'Error occurred during session transaction: {e}')
await session.rollback()
raise
else:
yield session

Expand Down
7 changes: 7 additions & 0 deletions src/crawlee/storage_clients/_sql/_dataset_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

from sqlalchemy import Select, insert, select
from sqlalchemy import func as sql_func
from sqlalchemy.exc import SQLAlchemyError
from typing_extensions import Self, override

from crawlee._utils.retry import retry_on_error
from crawlee.storage_clients._base import DatasetClient
from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata

Expand Down Expand Up @@ -109,11 +111,13 @@ async def open(
extra_metadata_fields={'item_count': 0},
)

@retry_on_error(SQLAlchemyError)
@override
async def get_metadata(self) -> DatasetMetadata:
# The database is a single place of truth
return await self._get_metadata(DatasetMetadata)

@retry_on_error(SQLAlchemyError)
@override
async def drop(self) -> None:
"""Delete this dataset and all its items from the database.
Expand All @@ -122,6 +126,7 @@ async def drop(self) -> None:
"""
await self._drop()

@retry_on_error(SQLAlchemyError)
@override
async def purge(self) -> None:
"""Remove all items from this dataset while keeping the dataset structure.
Expand All @@ -137,6 +142,7 @@ async def purge(self) -> None:
)
)

@retry_on_error(SQLAlchemyError)
@override
async def push_data(self, data: list[dict[str, Any]] | dict[str, Any]) -> None:
if not isinstance(data, list):
Expand All @@ -150,6 +156,7 @@ async def push_data(self, data: list[dict[str, Any]] | dict[str, Any]) -> None:

await self._add_buffer_record(session, update_modified_at=True, delta_item_count=len(data))

@retry_on_error(SQLAlchemyError)
@override
async def get_data(
self,
Expand Down
9 changes: 9 additions & 0 deletions src/crawlee/storage_clients/_sql/_key_value_store_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@

from sqlalchemy import CursorResult, delete, select
from sqlalchemy import func as sql_func
from sqlalchemy.exc import SQLAlchemyError
from typing_extensions import Self, override

from crawlee._utils.file import infer_mime_type
from crawlee._utils.retry import retry_on_error
from crawlee.storage_clients._base import KeyValueStoreClient
from crawlee.storage_clients.models import (
KeyValueStoreMetadata,
Expand Down Expand Up @@ -117,11 +119,13 @@ async def open(
extra_metadata_fields={},
)

@retry_on_error(SQLAlchemyError)
@override
async def get_metadata(self) -> KeyValueStoreMetadata:
# The database is a single place of truth
return await self._get_metadata(KeyValueStoreMetadata)

@retry_on_error(SQLAlchemyError)
@override
async def drop(self) -> None:
"""Delete this key-value store and all its records from the database.
Expand All @@ -130,6 +134,7 @@ async def drop(self) -> None:
"""
await self._drop()

@retry_on_error(SQLAlchemyError)
@override
async def purge(self) -> None:
"""Remove all items from this key-value store while keeping the key-value store structure.
Expand All @@ -139,6 +144,7 @@ async def purge(self) -> None:
now = datetime.now(timezone.utc)
await self._purge(metadata_kwargs=MetadataUpdateParams(accessed_at=now, modified_at=now))

@retry_on_error(SQLAlchemyError)
@override
async def set_value(self, *, key: str, value: Any, content_type: str | None = None) -> None:
# Special handling for None values
Expand Down Expand Up @@ -180,6 +186,7 @@ async def set_value(self, *, key: str, value: Any, content_type: str | None = No

await self._add_buffer_record(session, update_modified_at=True)

@retry_on_error(SQLAlchemyError)
@override
async def get_value(self, *, key: str) -> KeyValueStoreRecord | None:
# Query the record by key
Expand Down Expand Up @@ -226,6 +233,7 @@ async def get_value(self, *, key: str) -> KeyValueStoreRecord | None:
size=record_db.size,
)

@retry_on_error(SQLAlchemyError)
@override
async def delete_value(self, *, key: str) -> None:
stmt = delete(self._ITEM_TABLE).where(
Expand Down Expand Up @@ -274,6 +282,7 @@ async def iterate_keys(

await self._add_buffer_record(session)

@retry_on_error(SQLAlchemyError)
@override
async def record_exists(self, *, key: str) -> bool:
stmt = select(self._ITEM_TABLE.key).where(
Expand Down
Loading
Loading