diff --git a/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_expressions.py b/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_expressions.py index 43d0e4f0af07..5f52d92fd3bd 100644 --- a/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_expressions.py +++ b/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_expressions.py @@ -72,6 +72,11 @@ class Expression(ABC): together method calls to create complex expressions. """ + # Controls whether expression methods (e.g., .add(), .multiply()) can be called on + # instances of this class or its subclasses. Set to False for non-computational + # expressions like AliasedExpression. + _supports_expr_methods = True + def __repr__(self): return f"{self.__class__.__name__}()" @@ -113,6 +118,10 @@ def __init__(self, instance_func): self.instance_func = instance_func def static_func(self, first_arg, *other_args, **kwargs): + if getattr(first_arg, "_supports_expr_methods", True) is False: + raise TypeError( + f"Cannot call '{self.instance_func.__name__}' on {type(first_arg).__name__}." + ) if not isinstance(first_arg, (Expression, str)): raise TypeError( f"'{self.instance_func.__name__}' must be called on an Expression or a string representing a field. got {type(first_arg)}." @@ -128,6 +137,10 @@ def __get__(self, instance, owner): if instance is None: return self.static_func else: + if getattr(instance, "_supports_expr_methods", True) is False: + raise TypeError( + f"Cannot call '{self.instance_func.__name__}' on {type(instance).__name__}." + ) return self.instance_func.__get__(instance, owner) @expose_as_static @@ -2715,10 +2728,21 @@ def _to_value(field_list: Sequence[Selectable]) -> Value: class AliasedExpression(Selectable, Generic[T]): """Wraps an expression with an alias.""" + _supports_expr_methods = False + def __init__(self, expr: T, alias: str): + if isinstance(expr, AliasedExpression): + raise TypeError( + "Cannot wrap an AliasedExpression with another alias. An alias can only be applied once." + ) self.expr = expr self.alias = alias + def as_(self, alias: str) -> "AliasedExpression": + raise TypeError( + "Cannot call as_() on an AliasedExpression. An alias can only be applied once." + ) + def _to_map(self): return self.alias, self.expr._to_pb() diff --git a/packages/google-cloud-firestore/tests/system/pipeline_e2e/general.yaml b/packages/google-cloud-firestore/tests/system/pipeline_e2e/general.yaml index 2c901e6c6f7f..6ac8a784da05 100644 --- a/packages/google-cloud-firestore/tests/system/pipeline_e2e/general.yaml +++ b/packages/google-cloud-firestore/tests/system/pipeline_e2e/general.yaml @@ -896,3 +896,42 @@ tests: - Pipeline: - Subcollection: reviews assert_error: ".*start of a nested pipeline.*" + - description: cannot_call_expression_methods_on_aliased_expression + pipeline: + - Collection: books + - Select: + - FunctionExpression.add: + - AliasedExpression: + - Field: pages + - pages_alias + - 5 + assert_error: "Cannot call 'add' on AliasedExpression" + - description: cannot_chain_aliases + pipeline: + - Collection: books + - Select: + - AliasedExpression: + - AliasedExpression: + - Field: pages + - pages_alias + - final_alias + assert_error: "Cannot wrap an AliasedExpression" + - description: valid_aliased_expression_proto + pipeline: + - Collection: books + - Select: + - AliasedExpression: + - Field: pages + - pages_alias + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + pages_alias: + fieldReferenceValue: pages + name: select diff --git a/packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py b/packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py index 6881279c665b..149182b13694 100644 --- a/packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py +++ b/packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py @@ -24,7 +24,6 @@ import pytest import yaml -from google.api_core.exceptions import GoogleAPIError from google.protobuf.json_format import MessageToDict from test__helpers import FIRESTORE_EMULATOR, FIRESTORE_ENTERPRISE_DB, system_test_lock @@ -124,9 +123,9 @@ def test_pipeline_expected_errors(test_dict, client): Finds assert_error statements in yaml, and ensures the pipeline raises the expected error """ error_regex = test_dict["assert_error"] - pipeline = parse_pipeline(client, test_dict["pipeline"]) - # check if server responds as expected - with pytest.raises(GoogleAPIError) as err: + + with pytest.raises(Exception) as err: + pipeline = parse_pipeline(client, test_dict["pipeline"]) pipeline.execute() found_error = str(err.value) match = re.search(error_regex, found_error) @@ -215,9 +214,8 @@ async def test_pipeline_expected_errors_async(test_dict, async_client): Finds assert_error statements in yaml, and ensures the pipeline raises the expected error """ error_regex = test_dict["assert_error"] - pipeline = parse_pipeline(async_client, test_dict["pipeline"]) - # check if server responds as expected - with pytest.raises(GoogleAPIError) as err: + with pytest.raises(Exception) as err: + pipeline = parse_pipeline(async_client, test_dict["pipeline"]) await pipeline.execute() found_error = str(err.value) match = re.search(error_regex, found_error) diff --git a/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_expressions.py b/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_expressions.py index d54b0b626fd6..f76016805729 100644 --- a/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_expressions.py +++ b/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_expressions.py @@ -260,6 +260,24 @@ def test_to_map(self): assert result[0] == "alias1" assert result[1] == Value(field_reference_value="field1") + def test_chaining_aliases(self): + with pytest.raises( + TypeError, match="Cannot call as_\\(\\) on an AliasedExpression" + ): + Field.of("field1").as_("alias1").as_("alias2") + + def test_expr_method_on_aliased_raises_error(self): + with pytest.raises( + TypeError, match="Cannot call 'add' on AliasedExpression" + ): + Field.of("field1").as_("alias1").add(5) + + def test_static_expr_method_on_aliased_raises_error(self): + with pytest.raises( + TypeError, match="Cannot call 'add' on AliasedExpression" + ): + expr.Expression.add(Field.of("field1").as_("alias1"), 5) + class TestBooleanExpression: def test__from_query_filter_pb_composite_filter_or(self, mock_client):