Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ class Expression(ABC):
together method calls to create complex expressions.
"""

# Controls whether expression methods (e.g., .add(), .multiply()) can be called on
# instances of this class or its subclasses. Set to False for non-computational
# expressions like AliasedExpression.
_supports_expr_methods = True

def __repr__(self):
return f"{self.__class__.__name__}()"

Expand Down Expand Up @@ -113,6 +118,10 @@ def __init__(self, instance_func):
self.instance_func = instance_func

def static_func(self, first_arg, *other_args, **kwargs):
if getattr(first_arg, "_supports_expr_methods", True) is False:
raise TypeError(
f"Cannot call '{self.instance_func.__name__}' on {type(first_arg).__name__}."
)
if not isinstance(first_arg, (Expression, str)):
raise TypeError(
f"'{self.instance_func.__name__}' must be called on an Expression or a string representing a field. got {type(first_arg)}."
Expand All @@ -128,6 +137,10 @@ def __get__(self, instance, owner):
if instance is None:
return self.static_func
else:
if getattr(instance, "_supports_expr_methods", True) is False:
raise TypeError(
f"Cannot call '{self.instance_func.__name__}' on {type(instance).__name__}."
)
return self.instance_func.__get__(instance, owner)

@expose_as_static
Expand Down Expand Up @@ -2715,10 +2728,21 @@ def _to_value(field_list: Sequence[Selectable]) -> Value:
class AliasedExpression(Selectable, Generic[T]):
"""Wraps an expression with an alias."""

_supports_expr_methods = False

def __init__(self, expr: T, alias: str):
if isinstance(expr, AliasedExpression):
raise TypeError(
"Cannot wrap an AliasedExpression with another alias. An alias can only be applied once."
)
self.expr = expr
self.alias = alias

def as_(self, alias: str) -> "AliasedExpression":
raise TypeError(
"Cannot call as_() on an AliasedExpression. An alias can only be applied once."
)

def _to_map(self):
return self.alias, self.expr._to_pb()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -896,3 +896,42 @@ tests:
- Pipeline:
- Subcollection: reviews
assert_error: ".*start of a nested pipeline.*"
- description: cannot_call_expression_methods_on_aliased_expression
pipeline:
- Collection: books
- Select:
- FunctionExpression.add:
- AliasedExpression:
- Field: pages
- pages_alias
- 5
assert_error: "Cannot call 'add' on AliasedExpression"
- description: cannot_chain_aliases
pipeline:
- Collection: books
- Select:
- AliasedExpression:
- AliasedExpression:
- Field: pages
- pages_alias
- final_alias
assert_error: "Cannot wrap an AliasedExpression"
- description: valid_aliased_expression_proto
pipeline:
- Collection: books
- Select:
- AliasedExpression:
- Field: pages
- pages_alias
assert_proto:
pipeline:
stages:
- args:
- referenceValue: /books
name: collection
- args:
- mapValue:
fields:
pages_alias:
fieldReferenceValue: pages
name: select
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

import pytest
import yaml
from google.api_core.exceptions import GoogleAPIError
from google.protobuf.json_format import MessageToDict
from test__helpers import FIRESTORE_EMULATOR, FIRESTORE_ENTERPRISE_DB, system_test_lock

Expand Down Expand Up @@ -124,9 +123,9 @@ def test_pipeline_expected_errors(test_dict, client):
Finds assert_error statements in yaml, and ensures the pipeline raises the expected error
"""
error_regex = test_dict["assert_error"]
pipeline = parse_pipeline(client, test_dict["pipeline"])
# check if server responds as expected
with pytest.raises(GoogleAPIError) as err:

with pytest.raises(Exception) as err:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Catching a base Exception is too broad and can mask unrelated bugs or unexpected errors (e.g., NameError, AttributeError, or malformed test data) that might occur during pipeline parsing or execution. Since the expected errors are either client-side validation errors (TypeError) or server-side errors (GoogleAPIError), it is better to catch only those specific types.

Suggested change
with pytest.raises(Exception) as err:
with pytest.raises((GoogleAPIError, TypeError)) as err:

pipeline = parse_pipeline(client, test_dict["pipeline"])
pipeline.execute()
found_error = str(err.value)
match = re.search(error_regex, found_error)
Expand Down Expand Up @@ -215,9 +214,8 @@ async def test_pipeline_expected_errors_async(test_dict, async_client):
Finds assert_error statements in yaml, and ensures the pipeline raises the expected error
"""
error_regex = test_dict["assert_error"]
pipeline = parse_pipeline(async_client, test_dict["pipeline"])
# check if server responds as expected
with pytest.raises(GoogleAPIError) as err:
with pytest.raises(Exception) as err:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Catching a base Exception is too broad. It is recommended to use a more specific tuple of expected exceptions, such as (GoogleAPIError, TypeError), to ensure that the test only passes for the intended failure modes and doesn't swallow unrelated issues.

Suggested change
with pytest.raises(Exception) as err:
with pytest.raises((GoogleAPIError, TypeError)) as err:

pipeline = parse_pipeline(async_client, test_dict["pipeline"])
await pipeline.execute()
found_error = str(err.value)
match = re.search(error_regex, found_error)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,24 @@ def test_to_map(self):
assert result[0] == "alias1"
assert result[1] == Value(field_reference_value="field1")

def test_chaining_aliases(self):
with pytest.raises(
TypeError, match="Cannot call as_\\(\\) on an AliasedExpression"
):
Field.of("field1").as_("alias1").as_("alias2")

def test_expr_method_on_aliased_raises_error(self):
with pytest.raises(
TypeError, match="Cannot call 'add' on AliasedExpression"
):
Field.of("field1").as_("alias1").add(5)

def test_static_expr_method_on_aliased_raises_error(self):
with pytest.raises(
TypeError, match="Cannot call 'add' on AliasedExpression"
):
expr.Expression.add(Field.of("field1").as_("alias1"), 5)


class TestBooleanExpression:
def test__from_query_filter_pb_composite_filter_or(self, mock_client):
Expand Down
Loading