From eb0d49f7c7f0a66af8d438da8da14b82368f3ac3 Mon Sep 17 00:00:00 2001 From: German Bernadskiy Date: Mon, 27 Nov 2023 17:45:55 +1000 Subject: [PATCH 1/4] added test case --- tests/test_api/test_api_sqla_with_includes.py | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/tests/test_api/test_api_sqla_with_includes.py b/tests/test_api/test_api_sqla_with_includes.py index 0ad62e83..f5969b0e 100644 --- a/tests/test_api/test_api_sqla_with_includes.py +++ b/tests/test_api/test_api_sqla_with_includes.py @@ -2520,6 +2520,54 @@ async def test_filter_none_instead_of_uuid( ], } + async def test_join_by_relationships_works_correctly_with_many_filters_for_one_field( + self, + app: FastAPI, + async_session: AsyncSession, + client: AsyncClient, + user_1: User, + user_1_post: PostComment, + ): + comment_1 = PostComment( + text=fake.sentence(), + post_id=user_1_post.id, + author_id=user_1.id, + ) + comment_2 = PostComment( + text=fake.sentence(), + post_id=user_1_post.id, + author_id=user_1.id, + ) + assert comment_1.text != comment_2.text + async_session.add_all([comment_1, comment_2]) + await async_session.commit() + + params = { + "filter": dumps( + [ + { + "name": "posts.comments.text", + "op": "eq", + "val": comment_1.text, + }, + { + "name": "posts.comments.text", + "op": "eq", + "val": comment_2.text, + }, + ], + ), + } + + url = app.url_path_for("get_user_list") + res = await client.get(url, params=params) + assert res.status_code == status.HTTP_200_OK, res.text + assert res.json() == { + "data": [], + "jsonapi": {"version": "1.0"}, + "meta": {"count": 1, "totalPages": 1}, + } + ASCENDING = "" DESCENDING = "-" From 5e25556c1a57beb54d1d3a4af9ff5bfe79657d89 Mon Sep 17 00:00:00 2001 From: German Bernadskiy Date: Thu, 30 Nov 2023 16:11:07 +1000 Subject: [PATCH 2/4] rewrote filtering logic --- .../data_layers/filtering/sqlalchemy.py | 683 ++++++++---------- fastapi_jsonapi/data_layers/sqla_orm.py | 14 +- tests/test_api/test_api_sqla_with_includes.py | 2 +- 3 files changed, 301 insertions(+), 398 deletions(-) diff --git a/fastapi_jsonapi/data_layers/filtering/sqlalchemy.py b/fastapi_jsonapi/data_layers/filtering/sqlalchemy.py index 866345c8..f427b137 100644 --- a/fastapi_jsonapi/data_layers/filtering/sqlalchemy.py +++ b/fastapi_jsonapi/data_layers/filtering/sqlalchemy.py @@ -1,458 +1,353 @@ """Helper to create sqlalchemy filters according to filter querystring parameter""" -import inspect -import logging from typing import ( Any, Callable, Dict, List, - Optional, - Tuple, + Set, Type, Union, ) -from pydantic import BaseConfig, BaseModel +from pydantic import BaseModel from pydantic.fields import ModelField -from pydantic.validators import _VALIDATORS, find_validators from sqlalchemy import and_, not_, or_ -from sqlalchemy.orm import InstrumentedAttribute, aliased -from sqlalchemy.sql.elements import BinaryExpression +from sqlalchemy.orm import aliased +from sqlalchemy.orm.attributes import InstrumentedAttribute +from sqlalchemy.orm.util import AliasedClass +from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList -from fastapi_jsonapi.data_layers.shared import create_filters_or_sorts from fastapi_jsonapi.data_typing import TypeModel, TypeSchema from fastapi_jsonapi.exceptions import InvalidFilters, InvalidType -from fastapi_jsonapi.exceptions.json_api import HTTPException from fastapi_jsonapi.schema import get_model_field, get_relationships -from fastapi_jsonapi.splitter import SPLIT_REL -from fastapi_jsonapi.utils.sqla import get_related_model_cls -log = logging.getLogger(__name__) +RELATIONSHIP_SPLITTER = "." -Filter = BinaryExpression -Join = List[Any] +RelationshipPath = str -FilterAndJoins = Tuple[ - Filter, - List[Join], -] -# The mapping with validators using by to cast raw value to instance of target type -REGISTERED_PYDANTIC_TYPES: Dict[Type, List[Callable]] = dict(_VALIDATORS) +class RelationshipInfo(BaseModel): + target_schema: Type[TypeSchema] + model: Type[TypeModel] + aliased_model: AliasedClass + column: InstrumentedAttribute -cast_failed = object() + class Config: + arbitrary_types_allowed = True -def create_filters(model: Type[TypeModel], filter_info: Union[list, dict], schema: Type[TypeSchema]): +def build_filter_expression( + schema_field: ModelField, + model_column: InstrumentedAttribute, + operator: str, + value: Any, +) -> BinaryExpression: """ - Apply filters from filters information to base query + Builds sqlalchemy filter expression, like YourModel.some_field == value - :param model: the model of the node - :param filter_info: current node filter information - :param schema: the resource - """ - return create_filters_or_sorts(model, filter_info, Node, schema) - - -class Node: - """Helper to recursively create filters with sqlalchemy according to filter querystring parameter""" - - def __init__(self, model: Type[TypeModel], filter_: dict, schema: Type[TypeSchema]) -> None: - """ - Initialize an instance of a filter node - - :param model: an sqlalchemy model - :param dict filter_: filters information of the current node and deeper nodes - :param schema: the serializer - """ - self.model = model - self.filter_ = filter_ - self.schema = schema - - def _check_can_be_none(self, fields: list[ModelField]) -> bool: - """ - Return True if None is possible value for target field - """ - return any(field_item.allow_none for field_item in fields) - - def _cast_value_with_scheme(self, field_types: List[ModelField], value: Any) -> Tuple[Any, List[str]]: - errors: List[str] = [] - casted_value = cast_failed - - for field_type in field_types: - try: - if isinstance(value, list): # noqa: SIM108 - casted_value = [field_type(item) for item in value] - else: - casted_value = field_type(value) - except (TypeError, ValueError) as ex: - errors.append(str(ex)) - - return casted_value, errors - - def create_filter(self, schema_field: ModelField, model_column, operator, value): - """ - Create sqlalchemy filter - - :param schema_field: - :param model_column: column sqlalchemy - :param operator: - :param value: - :return: - """ - """ - Custom sqlachemy filtering logic can be created in a schemas field for any operator - To implement a new filtering logic (override existing or create a new one) - create a method inside a field following this pattern: - `__sql_filter_`. Each filtering method has to accept these params: - * schema_field - schemas field instance - * model_column - sqlalchemy column instance - * value - filtering value - * operator - your operator, for example: "eq", "in", "ilike_str_array", ... - """ - # Here we have to deserialize and validate fields, that are used in filtering, - # so the Enum fields are loaded correctly - - if schema_field.sub_fields: # noqa: SIM108 - # Для случаев когда в схеме тип Union - fields = list(schema_field.sub_fields) - else: - fields = [schema_field] - - can_be_none = self._check_can_be_none(fields) - - if value is None: - if can_be_none: - return getattr(model_column, self.operator)(value) - - raise InvalidFilters(detail=f"The field `{schema_field.name}` can't be null") - - types = [i.type_ for i in fields] - clear_value = None - errors: List[str] = [] - - pydantic_types, userspace_types = self._separate_types(types) + Custom sqlalchemy filtering logic can be created in a schemas field for any operator + To implement a new filtering logic (override existing or create a new one) + create a method inside a field following this pattern: `__sql_filter_` - if pydantic_types: - func = self._cast_value_with_pydantic - if isinstance(value, list): - func = self._cast_iterable_with_pydantic - clear_value, errors = func(pydantic_types, value, schema_field) + :param schema_field: schemas field instance + :param model_column: sqlalchemy column instance + :param operator: your operator, for example: "eq", "in", "ilike_str_array", ... + :param value: filtering value - if clear_value is None and userspace_types: - log.warning("Filtering by user type values is not properly tested yet. Use this on your own risk.") - - clear_value, errors = self._cast_value_with_scheme(types, value) - - if clear_value is cast_failed: - raise InvalidType( - detail=f"Can't cast filter value `{value}` to arbitrary type.", - errors=[HTTPException(status_code=InvalidType.status_code, detail=str(err)) for err in errors], - ) + """ + fields = [schema_field] - # Если None, при этом поле обязательное (среди типов в аннотации нет None, то кидаем ошибку) - if clear_value is None and not can_be_none: - raise InvalidType( - detail=", ".join(errors), - pointer=schema_field.name, - ) + # for Union annotations + if schema_field.sub_fields: + fields = list(schema_field.sub_fields) - return getattr(model_column, self.operator)(clear_value) - - def _separate_types(self, types: List[Type]) -> Tuple[List[Type], List[Type]]: - """ - Separates the types into two kinds. - - The first are those for which - there are already validators defined by pydantic - str, int, datetime - and some other built-in types. - The second are all other types for which - the `arbitrary_types_allowed` config is applied when defining the pydantic model - """ - pydantic_types = [ - # skip format - type_ - for type_ in types - if type_ in REGISTERED_PYDANTIC_TYPES - ] - userspace_types = [ - # skip format - type_ - for type_ in types - if type_ not in REGISTERED_PYDANTIC_TYPES - ] - return pydantic_types, userspace_types - - def _validator_requires_model_field(self, validator: Callable) -> bool: - """ - Check if validator accepts the `field` param - - :param validator: - :return: - """ - signature = inspect.signature(validator) - parameters = signature.parameters - - if "field" not in parameters: - return False - - field_param = parameters["field"] - field_type = field_param.annotation - - return field_type == "ModelField" or field_type is ModelField - - def _cast_value_with_pydantic( - self, - types: List[Type], - value: Any, - schema_field: ModelField, - ) -> Tuple[Optional[Any], List[str]]: - result_value, errors = None, [] - - for type_to_cast in types: - for validator in find_validators(type_to_cast, BaseConfig): - args = [value] - # TODO: some other way to get all the validator's dependencies? - if self._validator_requires_model_field(validator): - args.append(schema_field) - try: - result_value = validator(*args) - except Exception as ex: - errors.append(str(ex)) - else: - return result_value, errors - - return None, errors - - def _cast_iterable_with_pydantic( - self, - types: List[Type], - values: List, - schema_field: ModelField, - ) -> Tuple[List, List[str]]: - type_cast_failed = False - failed_values = [] - - result_values: List[Any] = [] - errors: List[str] = [] - - for value in values: - casted_value, cast_errors = self._cast_value_with_pydantic( - types, - value, - schema_field, - ) - errors.extend(cast_errors) + casted_value = None + errors: List[str] = [] - if casted_value is None: - type_cast_failed = True - failed_values.append(value) + for cast_type in [field.type_ for field in fields]: + try: + casted_value = [cast_type(item) for item in value] if isinstance(value, list) else cast_type(value) + except (TypeError, ValueError) as ex: + errors.append(str(ex)) - continue + all_fields_required = all(field.required for field in fields) - result_values.append(casted_value) + if casted_value is None and all_fields_required: + raise InvalidType(detail=", ".join(errors)) - if type_cast_failed: - msg = f"Can't parse items {failed_values} of value {values}" - raise InvalidFilters(msg, pointer=schema_field.name) + return getattr(model_column, operator)(casted_value) - return result_values, errors - def resolve(self) -> FilterAndJoins: # noqa: PLR0911 - """Create filter for a particular node of the filter tree""" - if "or" in self.filter_: - return self._create_filters(type_filter="or") - if "and" in self.filter_: - return self._create_filters(type_filter="and") - if "not" in self.filter_: - filter_, joins = Node(self.model, self.filter_["not"], self.schema).resolve() - return not_(filter_), joins +def is_terminal_node(filter_item: dict) -> bool: + """ + If node shape is: + { + "name: ..., + "op: ..., + "val: ..., + } + """ + terminal_node_keys = {"name", "op", "val"} + return set(filter_item.keys()) == terminal_node_keys - value = self.value - operator = self.filter_["op"] - schema_field: ModelField = self.schema.__fields__[self.name] - custom_filter = schema_field.field_info.extra.get(f"_{operator}_sql_filter_") - if custom_filter: - return custom_filter( - schema_field=schema_field, - model_column=self.column, - value=value, - operator=operator, - ) +def is_relationship_filter(name: str) -> bool: + return RELATIONSHIP_SPLITTER in name - if SPLIT_REL in self.filter_.get("name", ""): - value = { - "name": SPLIT_REL.join(self.filter_["name"].split(SPLIT_REL)[1:]), - "op": operator, - "val": value, - } - return self._relationship_filtering(value) - if isinstance(value, dict): - return self._relationship_filtering(value) +def gather_relationship_paths(filter_item: Union[List, Dict]) -> Set[str]: + """ + Extracts relationship paths from query filter + """ + names = set() - if schema_field.sub_fields: # noqa: SIM108 - # Для случаев когда в схеме тип Union - types = [i.type_ for i in schema_field.sub_fields] - else: - types = [schema_field.type_] - for i_type in types: - try: - if issubclass(i_type, BaseModel): - value = { - "name": self.name, - "op": operator, - "val": value, - } - return self._relationship_filtering(value) - except (TypeError, ValueError): - pass - - return ( - self.create_filter( - schema_field=schema_field, - model_column=self.column, - operator=operator, - value=value, - ), - [], - ) + if isinstance(filter_item, list): + for sub_item in filter_item: + names.update(gather_relationship_paths(sub_item)) - def _relationship_filtering(self, value): - alias = aliased(self.related_model) - joins = [[alias, self.column]] - node = Node(alias, value, self.related_schema) - filters, new_joins = node.resolve() - joins.extend(new_joins) - return filters, joins - - def _create_filters(self, type_filter: str) -> FilterAndJoins: - """ - Create or / and filters - - :param type_filter: 'or' или 'and' - :return: - """ - nodes = [Node(self.model, filter_, self.schema).resolve() for filter_ in self.filter_[type_filter]] - joins = [] - for i_node in nodes: - joins.extend(i_node[1]) - op = and_ if type_filter == "and" else or_ - return op(*[i_node[0] for i_node in nodes]), joins - - @property - def name(self) -> str: - """ - Return the name of the node or raise a BadRequest exception - - :return str: the name of the field to filter on - """ - name = self.filter_.get("name") - - if name is None: - msg = "Can't find name of a filter" - raise InvalidFilters(msg) + elif is_terminal_node(filter_item): + name = filter_item["name"] - if SPLIT_REL in name: - name = name.split(SPLIT_REL)[0] + if RELATIONSHIP_SPLITTER not in name: + return set() - if name not in self.schema.__fields__: - msg = "{} has no attribute {}".format(self.schema.__name__, name) - raise InvalidFilters(msg) + return {RELATIONSHIP_SPLITTER.join(name.split(RELATIONSHIP_SPLITTER)[:-1])} - return name + else: + for sub_item in filter_item.values(): + names.update(gather_relationship_paths(sub_item)) - @property - def op(self) -> str: - """ - Return the operator of the node + return names - :return str: the operator to use in the filter - """ - try: - return self.filter_["op"] - except KeyError: - msg = "Can't find op of a filter" - raise InvalidFilters(msg) - @property - def column(self) -> InstrumentedAttribute: - """Get the column object""" - field = self.name +def get_model_column( + model: Type[TypeModel], + schema: Type[TypeSchema], + field_name: str, +) -> InstrumentedAttribute: + model_field = get_model_field(schema, field_name) - model_field = get_model_field(self.schema, field) + try: + return getattr(model, model_field) + except AttributeError: + msg = "{} has no attribute {}".format(model.__name__, model_field) + raise InvalidFilters(msg) - try: - return getattr(self.model, model_field) - except AttributeError: - msg = "{} has no attribute {}".format(self.model.__name__, model_field) - raise InvalidFilters(msg) - @property - def operator(self) -> name: - """ - Get the function operator from his name +def get_operator(model_column: InstrumentedAttribute, operator_name: str) -> str: + """ + Get the function operator from his name - :return callable: a callable to make operation on a column - """ - operators = (self.op, self.op + "_", "__" + self.op + "__") + :return callable: a callable to make operation on a column + """ + operators = ( + f"__{operator_name}__", + f"{operator_name}_", + operator_name, + ) + + for op in operators: + if hasattr(model_column, op): + return op + + msg = "{} has no operator {}".format(model_column.key, operator_name) + raise InvalidFilters(msg) + + +def get_custom_filter_expression_callable(schema_field, operator: str) -> Callable: + return schema_field.field_info.extra.get( + f"_{operator}_sql_filter_", + ) + + +def gather_relationships_info( + model: Type[TypeModel], + schema: Type[TypeSchema], + relationship_path: List[str], + collected_info: dict, + target_relationship_idx: int = 0, +) -> dict[RelationshipPath, RelationshipInfo]: + is_last_relationship = target_relationship_idx == len(relationship_path) - 1 + target_relationship_path = RELATIONSHIP_SPLITTER.join( + relationship_path[: target_relationship_idx + 1], + ) + target_relationship_name = relationship_path[target_relationship_idx] + + if target_relationship_name not in set(get_relationships(schema)): + msg = f"There are no relationship '{target_relationship_name}' defined in schema {schema.__name__}" + raise InvalidFilters(msg) - for op in operators: - if hasattr(self.column, op): - return op + target_schema = schema.__fields__[target_relationship_name].type_ + target_model = getattr(model, target_relationship_name).property.mapper.class_ + target_column = get_model_column( + model, + schema, + target_relationship_name, + ) + collected_info[target_relationship_path] = RelationshipInfo( + target_schema=target_schema, + model=target_model, + aliased_model=aliased(target_model), + column=target_column, + ) + + if not is_last_relationship: + return gather_relationships_info( + target_model, + target_schema, + relationship_path, + collected_info, + target_relationship_idx + 1, + ) - msg = "{} has no operator {}".format(self.column.key, self.op) - raise InvalidFilters(msg) + return collected_info + + +def gather_relationships( + entrypoint_model: Type[TypeModel], + schema: Type[TypeSchema], + relationship_paths: Set[str], +) -> dict[RelationshipPath, RelationshipInfo]: + collected_info = {} + for relationship_path in sorted(relationship_paths): + gather_relationships_info( + model=entrypoint_model, + schema=schema, + relationship_path=relationship_path.split(RELATIONSHIP_SPLITTER), + collected_info=collected_info, + ) - @property - def value(self) -> Union[dict, list, int, str, float]: - """ - Get the value to filter on - - :return: the value to filter on - """ - if self.filter_.get("field") is not None: - try: - result = getattr(self.model, self.filter_["field"]) - except AttributeError: - msg = "{} has no attribute {}".format(self.model.__name__, self.filter_["field"]) - raise InvalidFilters(msg) - else: - return result + return collected_info + + +def prepare_relationships_info( + model: Type[TypeModel], + schema: Type[TypeSchema], + filter_info: list, +): + # TODO: do this on application startup or use the cache + relationship_paths = gather_relationship_paths(filter_info) + return gather_relationships( + entrypoint_model=model, + schema=schema, + relationship_paths=relationship_paths, + ) + + +def build_filter_expressions( + filter_item: Union[dict, list], + target_schema: Type[TypeSchema], + target_model: Type[TypeModel], + relationships_info: dict[RelationshipPath, RelationshipInfo], +) -> Union[BinaryExpression, BooleanClauseList]: + """ + Builds sqlalchemy expression which can be use + in where condition: query(Model).where(build_filter_expressions(...)) + """ + if is_terminal_node(filter_item): + name = filter_item["name"] + target_schema = target_schema + + if is_relationship_filter(name): + *relationship_path, field_name = name.split(RELATIONSHIP_SPLITTER) + relationship_info: RelationshipInfo = relationships_info[RELATIONSHIP_SPLITTER.join(relationship_path)] + model_column = get_model_column( + model=relationship_info.aliased_model, + schema=relationship_info.target_schema, + field_name=field_name, + ) + target_schema = relationship_info.target_schema else: - if "val" not in self.filter_: - msg = "Can't find value or field in a filter" - raise InvalidFilters(msg) - - return self.filter_["val"] + field_name = name + model_column = get_model_column( + model=target_model, + schema=target_schema, + field_name=field_name, + ) - @property - def related_model(self): - """ - Get the related model of a relationship field + schema_field = target_schema.__fields__[field_name] - :return DeclarativeMeta: the related model - """ - relationship_field = self.name + custom_filter_expression = get_custom_filter_expression_callable( + schema_field=schema_field, + operator=filter_item["op"], + ) + if custom_filter_expression: + return custom_filter_expression( + schema_field=schema_field, + model_column=model_column, + value=filter_item["val"], + operator=filter_item["op"], + ) + else: + return build_filter_expression( + schema_field=schema_field, + model_column=model_column, + operator=get_operator( + model_column=model_column, + operator_name=filter_item["op"], + ), + value=filter_item["val"], + ) - if relationship_field not in get_relationships(self.schema): - msg = "{} has no relationship attribute {}".format(self.schema.__name__, relationship_field) + if isinstance(filter_item, dict): + sqla_logic_operators = { + "or": or_, + "and": and_, + "not": not_, + } + + if len(logic_operators := set(filter_item.keys())) > 1: + msg = ( + f"In each logic node expected one of operators: {set(sqla_logic_operators.keys())} " + f"but got {len(logic_operators)}: {logic_operators}" + ) raise InvalidFilters(msg) - return get_related_model_cls(self.model, get_model_field(self.schema, relationship_field)) + if (logic_operator := logic_operators.pop()) not in set(sqla_logic_operators.keys()): + msg = f"Not found logic operator {logic_operator} expected one of {set(sqla_logic_operators.keys())}" + raise InvalidFilters(msg) - @property - def related_schema(self): - """ - Get the related schema of a relationship field + op = sqla_logic_operators[logic_operator] - :return Schema: the related schema - """ - relationship_field = self.name + if logic_operator == "not": + return op( + build_filter_expressions( + filter_item=filter_item[logic_operator], + target_schema=target_schema, + target_model=target_model, + relationships_info=relationships_info, + ), + ) - if relationship_field not in get_relationships(self.schema): - msg = "{} has no relationship attribute {}".format(self.schema.__name__, relationship_field) - raise InvalidFilters(msg) + expressions = [] + for filter_sub_item in filter_item[logic_operator]: + expressions.append( + build_filter_expressions( + filter_item=filter_sub_item, + target_schema=target_schema, + target_model=target_model, + relationships_info=relationships_info, + ), + ) - return self.schema.__fields__[relationship_field].type_ + return op(*expressions) + + +def create_filters_and_joins( + filter_info: list, + model: Type[TypeModel], + schema: Type[TypeSchema], +): + relationships_info = prepare_relationships_info( + model=model, + schema=schema, + filter_info=filter_info, + ) + expressions = build_filter_expressions( + filter_item={"and": filter_info}, + target_model=model, + target_schema=schema, + relationships_info=relationships_info, + ) + joins = [(info.aliased_model, info.column) for info in relationships_info.values()] + return expressions, joins diff --git a/fastapi_jsonapi/data_layers/sqla_orm.py b/fastapi_jsonapi/data_layers/sqla_orm.py index 4a86ef34..6027f5cc 100644 --- a/fastapi_jsonapi/data_layers/sqla_orm.py +++ b/fastapi_jsonapi/data_layers/sqla_orm.py @@ -13,7 +13,9 @@ from fastapi_jsonapi import BadRequest from fastapi_jsonapi.data_layers.base import BaseDataLayer -from fastapi_jsonapi.data_layers.filtering.sqlalchemy import create_filters +from fastapi_jsonapi.data_layers.filtering.sqlalchemy import ( + create_filters_and_joins, +) from fastapi_jsonapi.data_layers.sorting.sqlalchemy import create_sorts from fastapi_jsonapi.data_typing import TypeModel, TypeSchema from fastapi_jsonapi.exceptions import ( @@ -626,10 +628,16 @@ def filter_query(self, query: "Select", filter_info: Optional[list]) -> "Select" :return: the sorted query. """ if filter_info: - filters, joins = create_filters(model=self.model, filter_info=filter_info, schema=self.schema) + filters, joins = create_filters_and_joins( + model=self.model, + filter_info=filter_info, + schema=self.schema, + ) + for i_join in joins: query = query.join(*i_join) - query = query.where(*filters) + + query = query.where(filters) return query diff --git a/tests/test_api/test_api_sqla_with_includes.py b/tests/test_api/test_api_sqla_with_includes.py index f5969b0e..9b2029df 100644 --- a/tests/test_api/test_api_sqla_with_includes.py +++ b/tests/test_api/test_api_sqla_with_includes.py @@ -2565,7 +2565,7 @@ async def test_join_by_relationships_works_correctly_with_many_filters_for_one_f assert res.json() == { "data": [], "jsonapi": {"version": "1.0"}, - "meta": {"count": 1, "totalPages": 1}, + "meta": {"count": 0, "totalPages": 1}, } From 3fae8f0b0b7dbe107bd456ef10edae704130fa3b Mon Sep 17 00:00:00 2001 From: German Bernadskiy Date: Tue, 16 Jan 2024 14:08:20 +1000 Subject: [PATCH 3/4] rebase --- .../data_layers/filtering/sqlalchemy.py | 196 ++++++++++++++++-- .../test_filtering/test_sqlalchemy.py | 48 ++--- 2 files changed, 198 insertions(+), 46 deletions(-) diff --git a/fastapi_jsonapi/data_layers/filtering/sqlalchemy.py b/fastapi_jsonapi/data_layers/filtering/sqlalchemy.py index f427b137..08051c99 100644 --- a/fastapi_jsonapi/data_layers/filtering/sqlalchemy.py +++ b/fastapi_jsonapi/data_layers/filtering/sqlalchemy.py @@ -1,16 +1,21 @@ """Helper to create sqlalchemy filters according to filter querystring parameter""" +import inspect +import logging from typing import ( Any, Callable, Dict, List, + Optional, Set, + Tuple, Type, Union, ) -from pydantic import BaseModel +from pydantic import BaseConfig, BaseModel from pydantic.fields import ModelField +from pydantic.validators import _VALIDATORS, find_validators from sqlalchemy import and_, not_, or_ from sqlalchemy.orm import aliased from sqlalchemy.orm.attributes import InstrumentedAttribute @@ -19,14 +24,22 @@ from fastapi_jsonapi.data_typing import TypeModel, TypeSchema from fastapi_jsonapi.exceptions import InvalidFilters, InvalidType +from fastapi_jsonapi.exceptions.json_api import HTTPException from fastapi_jsonapi.schema import get_model_field, get_relationships +log = logging.getLogger(__name__) + RELATIONSHIP_SPLITTER = "." +# The mapping with validators using by to cast raw value to instance of target type +REGISTERED_PYDANTIC_TYPES: Dict[Type, List[Callable]] = dict(_VALIDATORS) + +cast_failed = object() + RelationshipPath = str -class RelationshipInfo(BaseModel): +class RelationshipFilteringInfo(BaseModel): target_schema: Type[TypeSchema] model: Type[TypeModel] aliased_model: AliasedClass @@ -36,6 +49,129 @@ class Config: arbitrary_types_allowed = True +def check_can_be_none(fields: list[ModelField]) -> bool: + """ + Return True if None is possible value for target field + """ + return any(field_item.allow_none for field_item in fields) + + +def separate_types(types: List[Type]) -> Tuple[List[Type], List[Type]]: + """ + Separates the types into two kinds. + + The first are those for which there are already validators + defined by pydantic - str, int, datetime and some other built-in types. + The second are all other types for which the `arbitrary_types_allowed` + config is applied when defining the pydantic model + """ + pydantic_types = [ + # skip format + type_ + for type_ in types + if type_ in REGISTERED_PYDANTIC_TYPES + ] + userspace_types = [ + # skip format + type_ + for type_ in types + if type_ not in REGISTERED_PYDANTIC_TYPES + ] + return pydantic_types, userspace_types + + +def validator_requires_model_field(validator: Callable) -> bool: + """ + Check if validator accepts the `field` param + + :param validator: + :return: + """ + signature = inspect.signature(validator) + parameters = signature.parameters + + if "field" not in parameters: + return False + + field_param = parameters["field"] + field_type = field_param.annotation + + return field_type == "ModelField" or field_type is ModelField + + +def cast_value_with_pydantic( + types: List[Type], + value: Any, + schema_field: ModelField, +) -> Tuple[Optional[Any], List[str]]: + result_value, errors = None, [] + + for type_to_cast in types: + for validator in find_validators(type_to_cast, BaseConfig): + args = [value] + # TODO: some other way to get all the validator's dependencies? + if validator_requires_model_field(validator): + args.append(schema_field) + try: + result_value = validator(*args) + except Exception as ex: + errors.append(str(ex)) + else: + return result_value, errors + + return None, errors + + +def cast_iterable_with_pydantic( + types: List[Type], + values: List, + schema_field: ModelField, +) -> Tuple[List, List[str]]: + type_cast_failed = False + failed_values = [] + + result_values: List[Any] = [] + errors: List[str] = [] + + for value in values: + casted_value, cast_errors = cast_value_with_pydantic( + types, + value, + schema_field, + ) + errors.extend(cast_errors) + + if casted_value is None: + type_cast_failed = True + failed_values.append(value) + + continue + + result_values.append(casted_value) + + if type_cast_failed: + msg = f"Can't parse items {failed_values} of value {values}" + raise InvalidFilters(msg, pointer=schema_field.name) + + return result_values, errors + + +def cast_value_with_scheme(field_types: List[Type], value: Any) -> Tuple[Any, List[str]]: + errors: List[str] = [] + casted_value = cast_failed + + for field_type in field_types: + try: + if isinstance(value, list): # noqa: SIM108 + casted_value = [field_type(item) for item in value] + else: + casted_value = field_type(value) + except (TypeError, ValueError) as ex: + errors.append(str(ex)) + + return casted_value, errors + + def build_filter_expression( schema_field: ModelField, model_column: InstrumentedAttribute, @@ -61,19 +197,43 @@ def build_filter_expression( if schema_field.sub_fields: fields = list(schema_field.sub_fields) + can_be_none = check_can_be_none(fields) + + if value is None: + if can_be_none: + return getattr(model_column, operator)(value) + + raise InvalidFilters(detail=f"The field `{schema_field.name}` can't be null") + + types = [i.type_ for i in fields] casted_value = None errors: List[str] = [] - for cast_type in [field.type_ for field in fields]: - try: - casted_value = [cast_type(item) for item in value] if isinstance(value, list) else cast_type(value) - except (TypeError, ValueError) as ex: - errors.append(str(ex)) + pydantic_types, userspace_types = separate_types(types) + + if pydantic_types: + func = cast_value_with_pydantic + if isinstance(value, list): + func = cast_iterable_with_pydantic + casted_value, errors = func(pydantic_types, value, schema_field) - all_fields_required = all(field.required for field in fields) + if casted_value is None and userspace_types: + log.warning("Filtering by user type values is not properly tested yet. Use this on your own risk.") - if casted_value is None and all_fields_required: - raise InvalidType(detail=", ".join(errors)) + casted_value, errors = cast_value_with_scheme(types, value) + + if casted_value is cast_failed: + raise InvalidType( + detail=f"Can't cast filter value `{value}` to arbitrary type.", + errors=[HTTPException(status_code=InvalidType.status_code, detail=str(err)) for err in errors], + ) + + # Если None, при этом поле обязательное (среди типов в аннотации нет None, то кидаем ошибку) + if casted_value is None and not can_be_none: + raise InvalidType( + detail=", ".join(errors), + pointer=schema_field.name, + ) return getattr(model_column, operator)(casted_value) @@ -81,6 +241,7 @@ def build_filter_expression( def is_terminal_node(filter_item: dict) -> bool: """ If node shape is: + { "name: ..., "op: ..., @@ -166,7 +327,7 @@ def gather_relationships_info( relationship_path: List[str], collected_info: dict, target_relationship_idx: int = 0, -) -> dict[RelationshipPath, RelationshipInfo]: +) -> dict[RelationshipPath, RelationshipFilteringInfo]: is_last_relationship = target_relationship_idx == len(relationship_path) - 1 target_relationship_path = RELATIONSHIP_SPLITTER.join( relationship_path[: target_relationship_idx + 1], @@ -184,7 +345,7 @@ def gather_relationships_info( schema, target_relationship_name, ) - collected_info[target_relationship_path] = RelationshipInfo( + collected_info[target_relationship_path] = RelationshipFilteringInfo( target_schema=target_schema, model=target_model, aliased_model=aliased(target_model), @@ -207,7 +368,7 @@ def gather_relationships( entrypoint_model: Type[TypeModel], schema: Type[TypeSchema], relationship_paths: Set[str], -) -> dict[RelationshipPath, RelationshipInfo]: +) -> dict[RelationshipPath, RelationshipFilteringInfo]: collected_info = {} for relationship_path in sorted(relationship_paths): gather_relationships_info( @@ -238,19 +399,22 @@ def build_filter_expressions( filter_item: Union[dict, list], target_schema: Type[TypeSchema], target_model: Type[TypeModel], - relationships_info: dict[RelationshipPath, RelationshipInfo], + relationships_info: dict[RelationshipPath, RelationshipFilteringInfo], ) -> Union[BinaryExpression, BooleanClauseList]: """ + Return sqla expressions. + Builds sqlalchemy expression which can be use in where condition: query(Model).where(build_filter_expressions(...)) """ if is_terminal_node(filter_item): name = filter_item["name"] - target_schema = target_schema if is_relationship_filter(name): *relationship_path, field_name = name.split(RELATIONSHIP_SPLITTER) - relationship_info: RelationshipInfo = relationships_info[RELATIONSHIP_SPLITTER.join(relationship_path)] + relationship_info: RelationshipFilteringInfo = relationships_info[ + RELATIONSHIP_SPLITTER.join(relationship_path) + ] model_column = get_model_column( model=relationship_info.aliased_model, schema=relationship_info.target_schema, diff --git a/tests/test_data_layers/test_filtering/test_sqlalchemy.py b/tests/test_data_layers/test_filtering/test_sqlalchemy.py index 18f5e4ef..8c3e4949 100644 --- a/tests/test_data_layers/test_filtering/test_sqlalchemy.py +++ b/tests/test_data_layers/test_filtering/test_sqlalchemy.py @@ -1,47 +1,41 @@ from typing import Any -from unittest.mock import Mock +from unittest.mock import MagicMock, Mock from fastapi import status from pydantic import BaseModel from pytest import raises # noqa PT013 -from fastapi_jsonapi.data_layers.filtering.sqlalchemy import Node -from fastapi_jsonapi.exceptions.json_api import InvalidType +from fastapi_jsonapi.data_layers.filtering.sqlalchemy import ( + build_filter_expression, +) +from fastapi_jsonapi.exceptions import InvalidType -class TestNode: +class TestFilteringFuncs: def test_user_type_cast_success(self): class UserType: def __init__(self, *args, **kwargs): - self.value = "success" + pass class ModelSchema(BaseModel): - user_type: UserType + value: UserType class Config: arbitrary_types_allowed = True - node = Node( - model=Mock(), - filter_={ - "name": "user_type", - "op": "eq", - "val": Any, - }, - schema=ModelSchema, - ) - - model_column_mock = Mock() - model_column_mock.eq = lambda clear_value: clear_value + model_column_mock = MagicMock() - clear_value = node.create_filter( - schema_field=ModelSchema.__fields__["user_type"], + build_filter_expression( + schema_field=ModelSchema.__fields__["value"], model_column=model_column_mock, - operator=Mock(), + operator="__eq__", value=Any, ) - assert isinstance(clear_value, UserType) - assert clear_value.value == "success" + + model_column_mock.__eq__.assert_called_once() + + call_arg = model_column_mock.__eq__.call_args[0] + isinstance(call_arg, UserType) def test_user_type_cast_fail(self): class UserType: @@ -55,14 +49,8 @@ class ModelSchema(BaseModel): class Config: arbitrary_types_allowed = True - node = Node( - model=Mock(), - filter_=Mock(), - schema=ModelSchema, - ) - with raises(InvalidType) as exc_info: - node.create_filter( + build_filter_expression( schema_field=ModelSchema.__fields__["user_type"], model_column=Mock(), operator=Mock(), From 5507360b1c9d3d767dbe843126fd543d54289434 Mon Sep 17 00:00:00 2001 From: German Bernadskiy Date: Thu, 18 Jan 2024 13:22:18 +1000 Subject: [PATCH 4/4] issue updates --- fastapi_jsonapi/data_layers/filtering/sqlalchemy.py | 5 +++-- tests/test_data_layers/test_filtering/test_sqlalchemy.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/fastapi_jsonapi/data_layers/filtering/sqlalchemy.py b/fastapi_jsonapi/data_layers/filtering/sqlalchemy.py index 08051c99..3cdfe097 100644 --- a/fastapi_jsonapi/data_layers/filtering/sqlalchemy.py +++ b/fastapi_jsonapi/data_layers/filtering/sqlalchemy.py @@ -168,6 +168,8 @@ def cast_value_with_scheme(field_types: List[Type], value: Any) -> Tuple[Any, Li casted_value = field_type(value) except (TypeError, ValueError) as ex: errors.append(str(ex)) + else: + return casted_value, errors return casted_value, errors @@ -228,7 +230,6 @@ def build_filter_expression( errors=[HTTPException(status_code=InvalidType.status_code, detail=str(err)) for err in errors], ) - # Если None, при этом поле обязательное (среди типов в аннотации нет None, то кидаем ошибку) if casted_value is None and not can_be_none: raise InvalidType( detail=", ".join(errors), @@ -256,7 +257,7 @@ def is_relationship_filter(name: str) -> bool: return RELATIONSHIP_SPLITTER in name -def gather_relationship_paths(filter_item: Union[List, Dict]) -> Set[str]: +def gather_relationship_paths(filter_item: Union[dict, list]) -> Set[str]: """ Extracts relationship paths from query filter """ diff --git a/tests/test_data_layers/test_filtering/test_sqlalchemy.py b/tests/test_data_layers/test_filtering/test_sqlalchemy.py index 8c3e4949..ec27a528 100644 --- a/tests/test_data_layers/test_filtering/test_sqlalchemy.py +++ b/tests/test_data_layers/test_filtering/test_sqlalchemy.py @@ -15,7 +15,7 @@ class TestFilteringFuncs: def test_user_type_cast_success(self): class UserType: def __init__(self, *args, **kwargs): - pass + """This method is needed to handle incoming arguments""" class ModelSchema(BaseModel): value: UserType