From 80df7fa4910fdc5d331c4b5d78bce073f620ce06 Mon Sep 17 00:00:00 2001 From: German Bernadskiy Date: Wed, 6 Mar 2024 19:15:39 +1000 Subject: [PATCH 1/7] added test case and base logic --- fastapi_jsonapi/querystring.py | 42 +++++++------- fastapi_jsonapi/schema_builder.py | 1 + fastapi_jsonapi/views/list_view.py | 38 ++++++++++++- fastapi_jsonapi/views/utils.py | 12 +++- fastapi_jsonapi/views/view_base.py | 8 ++- tests/test_api/test_api_sqla_with_includes.py | 56 ++++++++++++++++++- 6 files changed, 131 insertions(+), 26 deletions(-) diff --git a/fastapi_jsonapi/querystring.py b/fastapi_jsonapi/querystring.py index 8503e23b..9731c46c 100644 --- a/fastapi_jsonapi/querystring.py +++ b/fastapi_jsonapi/querystring.py @@ -1,4 +1,5 @@ """Helper to deal with querystring parameters according to jsonapi specification.""" +from collections import defaultdict from functools import cached_property from typing import ( TYPE_CHECKING, @@ -22,9 +23,9 @@ ) from starlette.datastructures import QueryParams +from fastapi_jsonapi.api import RoutersJSONAPI from fastapi_jsonapi.exceptions import ( BadRequest, - InvalidField, InvalidFilters, InvalidInclude, InvalidSort, @@ -32,7 +33,6 @@ from fastapi_jsonapi.schema import ( get_model_field, get_relationships, - get_schema_from_type, ) from fastapi_jsonapi.splitter import SPLIT_REL @@ -97,7 +97,7 @@ def _get_key_values(self, name: str) -> Dict[str, Union[List[str], str]]: :return: a dict of key / values items :raises BadRequest: if an error occurred while parsing the querystring. """ - results: Dict[str, Union[List[str], str]] = {} + results = defaultdict(set) for raw_key, value in self.qs.multi_items(): key = unquote(raw_key) @@ -109,10 +109,7 @@ def _get_key_values(self, name: str) -> Dict[str, Union[List[str], str]]: key_end = key.index("]") item_key = key[key_start:key_end] - if "," in value: - results.update({item_key: value.split(",")}) - else: - results.update({item_key: value}) + results[item_key].update(value.split(",")) except Exception: msg = "Parse error" raise BadRequest(msg, parameter=key) @@ -216,27 +213,28 @@ def fields(self) -> Dict[str, List[str]]: :raises InvalidField: if result field not in schema. """ - if self.request.method != "GET": - msg = "attribute 'fields' allowed only for GET-method" - raise InvalidField(msg) fields = self._get_key_values("fields") - for key, value in fields.items(): - if not isinstance(value, list): - value = [value] # noqa: PLW2901 - fields[key] = value + for resource_type, field_names in fields.items(): # TODO: we have registry for models (BaseModel) # TODO: create `type to schemas` registry - schema: Type[BaseModel] = get_schema_from_type(key, self.app) - for field in value: - if field not in schema.__fields__: - msg = "{schema} has no attribute {field}".format( - schema=schema.__name__, - field=field, - ) - raise InvalidField(msg) + + # schema: Type[BaseModel] = get_schema_from_type(key, self.app) + self._get_schema(resource_type) + + # for field_name in field_names: + # if field_name not in schema.__fields__: + # msg = "{schema} has no attribute {field}".format( + # schema=schema.__name__, + # field=field_name, + # ) + # raise InvalidField(msg) return fields + def _get_schema(self, resource_type: str) -> Type[BaseModel]: + target_router = RoutersJSONAPI.all_jsonapi_routers[resource_type] + return target_router.detail_response_schema + def get_sorts(self, schema: Type["TypeSchema"]) -> List[Dict[str, str]]: """ Return fields to sort by including sort name for SQLAlchemy and row sort parameter for other ORMs. diff --git a/fastapi_jsonapi/schema_builder.py b/fastapi_jsonapi/schema_builder.py index 3db08eeb..43b5caed 100644 --- a/fastapi_jsonapi/schema_builder.py +++ b/fastapi_jsonapi/schema_builder.py @@ -484,6 +484,7 @@ def create_jsonapi_object_schemas( base_name: str = "", compute_included_schemas: bool = False, use_schema_cache: bool = True, + exclude_attributes: Optional[List[str]] = None, ) -> JSONAPIObjectSchemas: if use_schema_cache and schema in self.object_schemas_cache and includes is not_passed: return self.object_schemas_cache[schema] diff --git a/fastapi_jsonapi/views/list_view.py b/fastapi_jsonapi/views/list_view.py index e7be5421..3893178f 100644 --- a/fastapi_jsonapi/views/list_view.py +++ b/fastapi_jsonapi/views/list_view.py @@ -6,6 +6,7 @@ JSONAPIResultDetailSchema, JSONAPIResultListSchema, ) +from fastapi_jsonapi.views.utils import get_includes_indexes_by_type from fastapi_jsonapi.views.view_base import ViewBase if TYPE_CHECKING: @@ -14,6 +15,31 @@ logger = logging.getLogger(__name__) +def calculate_include_fields(response, query_params, jsonapi) -> Dict: + included = "included" in response.__fields__ and response.included or [] + + include_params = { + field_name: {*response.__fields__[field_name].type_.__fields__.keys()} + for field_name in response.__fields__ + if field_name + } + include_params["included"] = {} + + includes_indexes_by_type = get_includes_indexes_by_type(included) + + for resource_type, field_names in query_params.fields.items(): + if resource_type == jsonapi.type_: + include_params["data"] = {"__all__": {"attributes": field_names, "id": {"id"}, "type": {"type"}}} + continue + + target_type_indexes = includes_indexes_by_type.get(resource_type) + + if resource_type in includes_indexes_by_type and target_type_indexes: + include_params["included"].update((idx, field_names) for idx in target_type_indexes) + + return include_params + + class ListViewBase(ViewBase): def _calculate_total_pages(self, db_items_count: int) -> int: total_pages = 1 @@ -40,7 +66,17 @@ async def handle_get_resource_list(self, **extra_view_deps) -> JSONAPIResultList count, items_from_db = await dl.get_collection(qs=query_params) total_pages = self._calculate_total_pages(count) - return self._build_list_response(items_from_db, count, total_pages) + response = self._build_list_response(items_from_db, count, total_pages) + + if not query_params.fields: + return response + + include_params = calculate_include_fields(response, query_params, self.jsonapi) + + if include_params: + return response.dict(include=include_params) + + return response async def handle_post_resource_list( self, diff --git a/fastapi_jsonapi/views/utils.py b/fastapi_jsonapi/views/utils.py index 5f80af1a..e82884d6 100644 --- a/fastapi_jsonapi/views/utils.py +++ b/fastapi_jsonapi/views/utils.py @@ -1,6 +1,7 @@ +from collections import defaultdict from enum import Enum from functools import cache -from typing import Callable, Coroutine, Optional, Set, Type, Union +from typing import Callable, Coroutine, Dict, List, Optional, Set, Type, Union from pydantic import BaseModel @@ -27,3 +28,12 @@ class Config: @property def handler(self) -> Optional[Union[Callable, Coroutine]]: return self.prepare_data_layer_kwargs + + +def get_includes_indexes_by_type(included: List[Dict]) -> Dict[str, List[int]]: + result = defaultdict(list) + + for idx, item in enumerate(included, 1): + result[item["type"]].append(idx) + + return result diff --git a/fastapi_jsonapi/views/view_base.py b/fastapi_jsonapi/views/view_base.py index 9ef8dee7..5566d7ab 100644 --- a/fastapi_jsonapi/views/view_base.py +++ b/fastapi_jsonapi/views/view_base.py @@ -30,6 +30,7 @@ from fastapi_jsonapi.schema import ( JSONAPIObjectSchema, JSONAPIResultListMetaSchema, + JSONAPIResultListSchema, get_related_schema, ) from fastapi_jsonapi.schema_base import BaseModel, RelationshipInfo @@ -185,7 +186,12 @@ def _build_detail_response(self, db_item: TypeModel): return detail_jsonapi_schema(data=result_object, **extras) - def _build_list_response(self, items_from_db: List[TypeModel], count: int, total_pages: int): + def _build_list_response( + self, + items_from_db: List[TypeModel], + count: int, + total_pages: int, + ) -> JSONAPIResultListSchema: result_objects, object_schemas, extras = self._build_response(items_from_db, self.jsonapi.schema_list) # we need to build a new schema here diff --git a/tests/test_api/test_api_sqla_with_includes.py b/tests/test_api/test_api_sqla_with_includes.py index d3c31d92..690faa15 100644 --- a/tests/test_api/test_api_sqla_with_includes.py +++ b/tests/test_api/test_api_sqla_with_includes.py @@ -4,7 +4,7 @@ from datetime import datetime, timezone from itertools import chain, zip_longest from json import dumps, loads -from typing import Dict, List, Literal +from typing import Dict, List, Literal, Set, Tuple from uuid import UUID, uuid4 import pytest @@ -16,6 +16,7 @@ from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import InstrumentedAttribute +from starlette.datastructures import QueryParams from fastapi_jsonapi.views.view_base import ViewBase from tests.common import is_postgres_tests @@ -151,6 +152,59 @@ async def test_get_users_paginated( "meta": {"count": 2, "totalPages": 2}, } + @mark.parametrize( + "fields, expected_include", + [ + param( + [ + ("fields[user]", "name,age"), + ], + {"name", "age"}, + ), + param( + [ + ("fields[user]", "name,age"), + ("fields[user]", "email"), + ], + {"name", "age", "email"}, + ), + ], + ) + async def test_select_custom_fields( + self, + app: FastAPI, + client: AsyncClient, + user_1: User, + user_2: User, + fields: List[Tuple[str, str]], + expected_include: Set[str], + ): + url = app.url_path_for("get_user_list") + user_1, user_2 = sorted((user_1, user_2), key=lambda x: x.id) + + params = QueryParams(fields) + response = await client.get(url, params=str(params)) + + assert response.status_code == status.HTTP_200_OK, response.text + response_data = response.json() + + assert response_data == { + "data": [ + { + "attributes": UserAttributesBaseSchema.from_orm(user_1).dict(include=expected_include), + "id": str(user_1.id), + "type": "user", + }, + { + "attributes": UserAttributesBaseSchema.from_orm(user_2).dict(include=expected_include), + "id": str(user_2.id), + "type": "user", + }, + ], + "jsonapi": {"version": "1.0"}, + "meta": {"count": 2, "total_pages": 1}, + } + class TestCreatePostAndComments: async def test_get_posts_with_users( From 7697f2c2f70eda0d0bbfcf6e11514151a27ab964 Mon Sep 17 00:00:00 2001 From: German Bernadskiy Date: Thu, 7 Mar 2024 16:05:33 +1000 Subject: [PATCH 2/7] updated logic to exclude --- fastapi_jsonapi/querystring.py | 74 +++++++---- fastapi_jsonapi/views/list_view.py | 42 +----- fastapi_jsonapi/views/utils.py | 125 +++++++++++++++++- tests/fixtures/entities.py | 59 +++++++-- tests/test_api/test_api_sqla_with_includes.py | 108 ++++++++++++++- 5 files changed, 324 insertions(+), 84 deletions(-) diff --git a/fastapi_jsonapi/querystring.py b/fastapi_jsonapi/querystring.py index 9731c46c..322591ec 100644 --- a/fastapi_jsonapi/querystring.py +++ b/fastapi_jsonapi/querystring.py @@ -8,7 +8,6 @@ List, Optional, Type, - Union, ) from urllib.parse import unquote @@ -26,9 +25,11 @@ from fastapi_jsonapi.api import RoutersJSONAPI from fastapi_jsonapi.exceptions import ( BadRequest, + InvalidField, InvalidFilters, InvalidInclude, InvalidSort, + InvalidType, ) from fastapi_jsonapi.schema import ( get_model_field, @@ -89,7 +90,16 @@ def __init__(self, request: Request) -> None: self.MAX_INCLUDE_DEPTH: int = self.config.get("MAX_INCLUDE_DEPTH", 3) self.headers: HeadersQueryStringManager = HeadersQueryStringManager(**dict(self.request.headers)) - def _get_key_values(self, name: str) -> Dict[str, Union[List[str], str]]: + def _extract_item_key(self, key: str) -> str: + try: + key_start = key.index("[") + 1 + key_end = key.index("]") + return key[key_start:key_end] + except Exception: + msg = "Parse error" + raise BadRequest(msg, parameter=key) + + def _get_unique_key_values(self, name: str) -> Dict[str, str]: """ Return a dict containing key / values items for a given key, used for items like filters, page, etc. @@ -97,22 +107,28 @@ def _get_key_values(self, name: str) -> Dict[str, Union[List[str], str]]: :return: a dict of key / values items :raises BadRequest: if an error occurred while parsing the querystring. """ - results = defaultdict(set) + results = {} for raw_key, value in self.qs.multi_items(): key = unquote(raw_key) - try: - if not key.startswith(name): - continue + if not key.startswith(name): + continue - key_start = key.index("[") + 1 - key_end = key.index("]") - item_key = key[key_start:key_end] + item_key = self._extract_item_key(key) + results[item_key] = value - results[item_key].update(value.split(",")) - except Exception: - msg = "Parse error" - raise BadRequest(msg, parameter=key) + return results + + def _get_multiple_key_values(self, name: str) -> Dict[str, List]: + results = defaultdict(list) + + for raw_key, value in self.qs.multi_items(): + key = unquote(raw_key) + if not key.startswith(name): + continue + + item_key = self._extract_item_key(key) + results[item_key].extend(value.split(",")) return results @@ -131,7 +147,7 @@ def querystring(self) -> Dict[str, str]: return { key: value for (key, value) in self.qs.multi_items() - if key.startswith(self.managed_keys) or self._get_key_values("filter[") + if key.startswith(self.managed_keys) or self._get_unique_key_values("filter[") } @property @@ -156,8 +172,8 @@ def filters(self) -> List[dict]: raise InvalidFilters(msg) results.extend(loaded_filters) - if self._get_key_values("filter["): - results.extend(self._simple_filters(self._get_key_values("filter["))) + if filter_key_values := self._get_unique_key_values("filter["): + results.extend(self._simple_filters(filter_key_values)) return results @cached_property @@ -183,7 +199,7 @@ def pagination(self) -> PaginationQueryStringManager: :raises BadRequest: if the client is not allowed to disable pagination. """ # check values type - pagination_data: Dict[str, Union[List[str], str]] = self._get_key_values("page") + pagination_data: Dict[str, str] = self._get_unique_key_values("page") pagination = PaginationQueryStringManager(**pagination_data) if pagination_data.get("size") is None: pagination.size = None @@ -213,23 +229,27 @@ def fields(self) -> Dict[str, List[str]]: :raises InvalidField: if result field not in schema. """ - fields = self._get_key_values("fields") + fields = self._get_multiple_key_values("fields") for resource_type, field_names in fields.items(): # TODO: we have registry for models (BaseModel) # TODO: create `type to schemas` registry - # schema: Type[BaseModel] = get_schema_from_type(key, self.app) + if resource_type not in RoutersJSONAPI.all_jsonapi_routers: + msg = f"Application has no resource with type {resource_type!r}" + raise InvalidType(msg) + + schema: Type[BaseModel] = RoutersJSONAPI.all_jsonapi_routers[resource_type]._schema self._get_schema(resource_type) - # for field_name in field_names: - # if field_name not in schema.__fields__: - # msg = "{schema} has no attribute {field}".format( - # schema=schema.__name__, - # field=field_name, - # ) - # raise InvalidField(msg) + for field_name in field_names: + if field_name not in schema.__fields__: + msg = "{schema} has no attribute {field}".format( + schema=schema.__name__, + field=field_name, + ) + raise InvalidField(msg) - return fields + return {resource_type: set(field_names) for resource_type, field_names in fields.items()} def _get_schema(self, resource_type: str) -> Type[BaseModel]: target_router = RoutersJSONAPI.all_jsonapi_routers[resource_type] diff --git a/fastapi_jsonapi/views/list_view.py b/fastapi_jsonapi/views/list_view.py index 3893178f..82a4e1ac 100644 --- a/fastapi_jsonapi/views/list_view.py +++ b/fastapi_jsonapi/views/list_view.py @@ -1,12 +1,12 @@ import logging -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any, Dict, Union from fastapi_jsonapi.schema import ( BaseJSONAPIItemInSchema, JSONAPIResultDetailSchema, JSONAPIResultListSchema, ) -from fastapi_jsonapi.views.utils import get_includes_indexes_by_type +from fastapi_jsonapi.views.utils import handle_fields from fastapi_jsonapi.views.view_base import ViewBase if TYPE_CHECKING: @@ -15,31 +15,6 @@ logger = logging.getLogger(__name__) -def calculate_include_fields(response, query_params, jsonapi) -> Dict: - included = "included" in response.__fields__ and response.included or [] - - include_params = { - field_name: {*response.__fields__[field_name].type_.__fields__.keys()} - for field_name in response.__fields__ - if field_name - } - include_params["included"] = {} - - includes_indexes_by_type = get_includes_indexes_by_type(included) - - for resource_type, field_names in query_params.fields.items(): - if resource_type == jsonapi.type_: - include_params["data"] = {"__all__": {"attributes": field_names, "id": {"id"}, "type": {"type"}}} - continue - - target_type_indexes = includes_indexes_by_type.get(resource_type) - - if resource_type in includes_indexes_by_type and target_type_indexes: - include_params["included"].update((idx, field_names) for idx in target_type_indexes) - - return include_params - - class ListViewBase(ViewBase): def _calculate_total_pages(self, db_items_count: int) -> int: total_pages = 1 @@ -60,23 +35,14 @@ async def get_data_layer( ) -> "BaseDataLayer": return await self.get_data_layer_for_list(extra_view_deps) - async def handle_get_resource_list(self, **extra_view_deps) -> JSONAPIResultListSchema: + async def handle_get_resource_list(self, **extra_view_deps) -> Union[JSONAPIResultListSchema, Dict]: dl: "BaseDataLayer" = await self.get_data_layer(extra_view_deps) query_params = self.query_params count, items_from_db = await dl.get_collection(qs=query_params) total_pages = self._calculate_total_pages(count) response = self._build_list_response(items_from_db, count, total_pages) - - if not query_params.fields: - return response - - include_params = calculate_include_fields(response, query_params, self.jsonapi) - - if include_params: - return response.dict(include=include_params) - - return response + return handle_fields(response, query_params, self.jsonapi) async def handle_post_resource_list( self, diff --git a/fastapi_jsonapi/views/utils.py b/fastapi_jsonapi/views/utils.py index e82884d6..8be31129 100644 --- a/fastapi_jsonapi/views/utils.py +++ b/fastapi_jsonapi/views/utils.py @@ -1,9 +1,39 @@ +from __future__ import annotations + from collections import defaultdict from enum import Enum from functools import cache -from typing import Callable, Coroutine, Dict, List, Optional, Set, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Coroutine, + Dict, + Iterable, + List, + Optional, + Set, + Type, + Union, +) from pydantic import BaseModel +from pydantic.fields import ModelField + +from fastapi_jsonapi.data_typing import TypeSchema +from fastapi_jsonapi.schema import JSONAPIObjectSchema +from fastapi_jsonapi.schema_builder import ( + JSONAPIResultDetailSchema, + JSONAPIResultListSchema, +) + +if TYPE_CHECKING: + from fastapi_jsonapi.api import RoutersJSONAPI + from fastapi_jsonapi.querystring import QueryStringManager + + +JSONAPIResponse = Union[JSONAPIResultDetailSchema, JSONAPIResultListSchema] +IGNORE_ALL_FIELDS_LITERAL = "" class HTTPMethod(Enum): @@ -30,10 +60,97 @@ def handler(self) -> Optional[Union[Callable, Coroutine]]: return self.prepare_data_layer_kwargs -def get_includes_indexes_by_type(included: List[Dict]) -> Dict[str, List[int]]: +def _get_includes_indexes_by_type(included: List[JSONAPIObjectSchema]) -> Dict[str, List[int]]: result = defaultdict(list) - for idx, item in enumerate(included, 1): - result[item["type"]].append(idx) + for idx, item in enumerate(included): + result[item.type].append(idx) return result + + +# TODO: move to schema builder? +def _is_relationship_field(field: ModelField) -> bool: + return "relationship" in field.field_info.extra + + +def _get_schema_field_names(schema: Type[TypeSchema]) -> Set[str]: + """ + Returns all attribute names except relationships + """ + result = set() + + for field_name, field in schema.__fields__.items(): + if _is_relationship_field(field): + continue + + result.add(field_name) + + return result + + +def _get_exclude_fields( + schema: Type[TypeSchema], + include_fields: Iterable[str], +) -> Set[str]: + schema_fields = _get_schema_field_names(schema) + + if IGNORE_ALL_FIELDS_LITERAL in include_fields: + return schema_fields + + return set(_get_schema_field_names(schema)).difference(include_fields) + + +def _calculate_exclude_fields( + response: JSONAPIResponse, + query_params: QueryStringManager, + jsonapi: RoutersJSONAPI, +) -> Dict: + included = "included" in response.__fields__ and response.included or [] + is_list_response = isinstance(response, JSONAPIResultListSchema) + + exclude_params: Dict[str, Any] = {} + + includes_indexes_by_type = _get_includes_indexes_by_type(included) + + for resource_type, field_names in query_params.fields.items(): + schema = jsonapi.all_jsonapi_routers[resource_type]._schema + exclude_fields = _get_exclude_fields(schema, include_fields=field_names) + attributes_exclude = {"attributes": exclude_fields} + + if resource_type == jsonapi.type_: + if is_list_response: + exclude_params["data"] = {"__all__": attributes_exclude} + else: + exclude_params["data"] = attributes_exclude + + continue + + if not included: + continue + + target_type_indexes = includes_indexes_by_type.get(resource_type) + + if target_type_indexes: + if "included" not in exclude_params: + exclude_params["included"] = {} + + exclude_params["included"].update((idx, attributes_exclude) for idx in target_type_indexes) + + return exclude_params + + +def handle_fields( + response: JSONAPIResponse, + query_params: QueryStringManager, + jsonapi: RoutersJSONAPI, +) -> Union[JSONAPIResponse, Dict]: + if not query_params.fields: + return response + + exclude_params = _calculate_exclude_fields(response, query_params, jsonapi) + + if exclude_params: + return response.dict(exclude=exclude_params) + + return response diff --git a/tests/fixtures/entities.py b/tests/fixtures/entities.py index 23eb3547..f5f45970 100644 --- a/tests/fixtures/entities.py +++ b/tests/fixtures/entities.py @@ -99,20 +99,31 @@ async def user_2_bio(async_session: AsyncSession, user_2: User) -> UserBio: ) +async def build_post(async_session: AsyncSession, user: User, **fields) -> Post: + fields = {"title": fake.name(), "body": fake.sentence(), **fields} + post = Post(user=user, **fields) + async_session.add(post) + await async_session.commit() + return post + + @async_fixture() -async def user_1_posts(async_session: AsyncSession, user_1: User): - posts = [Post(title=f"post_u1_{i}", user=user_1) for i in range(1, 4)] +async def user_1_posts(async_session: AsyncSession, user_1: User) -> List[Post]: + posts = [ + Post( + title=f"post_u1_{i}", + user=user_1, + body=fake.sentence(), + ) + for i in range(1, 4) + ] async_session.add_all(posts) await async_session.commit() for post in posts: await async_session.refresh(post) - yield posts - - for post in posts: - await async_session.delete(post) - await async_session.commit() + return posts @async_fixture() @@ -130,19 +141,22 @@ async def user_1_post(async_session: AsyncSession, user_1: User): @async_fixture() -async def user_2_posts(async_session: AsyncSession, user_2: User): - posts = [Post(title=f"post_u2_{i}", user=user_2) for i in range(1, 5)] +async def user_2_posts(async_session: AsyncSession, user_2: User) -> List[Post]: + posts = [ + Post( + title=f"post_u2_{i}", + user=user_2, + body=fake.sentence(), + ) + for i in range(1, 5) + ] async_session.add_all(posts) await async_session.commit() for post in posts: await async_session.refresh(post) - yield posts - - for post in posts: - await async_session.delete(post) - await async_session.commit() + return posts @async_fixture() @@ -213,6 +227,23 @@ async def factory(name: str | None = None) -> Computer: return factory +async def build_post_comment( + async_session: AsyncSession, + user: User, + post: Post, + **fields, +) -> PostComment: + fields = {"text": fake.sentence(), **fields} + post_comment = PostComment( + author=user, + post=post, + **fields, + ) + async_session.add(post_comment) + await async_session.commit() + return post_comment + + @async_fixture() async def user_2_comment_for_one_u1_post(async_session: AsyncSession, user_2, user_1_post_for_comments): post = user_1_post_for_comments diff --git a/tests/test_api/test_api_sqla_with_includes.py b/tests/test_api/test_api_sqla_with_includes.py index 690faa15..688e171f 100644 --- a/tests/test_api/test_api_sqla_with_includes.py +++ b/tests/test_api/test_api_sqla_with_includes.py @@ -21,7 +21,12 @@ from fastapi_jsonapi.views.view_base import ViewBase from tests.common import is_postgres_tests from tests.fixtures.app import build_alphabet_app, build_app_custom -from tests.fixtures.entities import build_workplace, create_user +from tests.fixtures.entities import ( + build_post, + build_post_comment, + build_workplace, + create_user, +) from tests.misc.utils import fake from tests.models import ( Alpha, @@ -205,6 +210,107 @@ async def test_select_custom_fields( "meta": {"count": 2, "total_pages": 1}, } + async def test_select_custom_fields_with_includes( + self, + app: FastAPI, + async_session: AsyncSession, + client: AsyncClient, + user_1: User, + user_2: User, + ): + url = app.url_path_for("get_user_list") + user_1, user_2 = sorted((user_1, user_2), key=lambda x: x.id) + + user_2_post = await build_post(async_session, user_2) + user_1_post = await build_post(async_session, user_1) + + user_1_comment = await build_post_comment(async_session, user_1, user_2_post) + user_2_comment = await build_post_comment(async_session, user_2, user_1_post) + + params = QueryParams( + [ + ("fields[user]", "name"), + ("fields[post]", "title"), + # empty str means ignore all fields + ("fields[comment]", ""), + ("include", "posts,posts.comments"), + ("sort", "id"), + ], + ) + response = await client.get(url, params=str(params)) + + assert response.status_code == status.HTTP_200_OK, response.text + response_data = response.json() + response_data["included"] = sorted(response_data["included"], key=lambda x: (x["type"], x["id"])) + + assert response_data == { + "data": [ + { + "attributes": UserAttributesBaseSchema.from_orm(user_1).dict(include={"name"}), + "relationships": { + "posts": { + "data": [ + { + "id": str(user_1_post.id), + "type": "post", + }, + ], + }, + }, + "id": str(user_1.id), + "type": "user", + }, + { + "attributes": UserAttributesBaseSchema.from_orm(user_2).dict(include={"name"}), + "relationships": { + "posts": { + "data": [ + { + "id": str(user_2_post.id), + "type": "post", + }, + ], + }, + }, + "id": str(user_2.id), + "type": "user", + }, + ], + "jsonapi": {"version": "1.0"}, + "meta": {"count": 2, "total_pages": 1}, + "included": sorted( + [ + { + "attributes": PostAttributesBaseSchema.from_orm(user_2_post).dict(include={"title"}), + "id": str(user_2_post.id), + "relationships": { + "comments": {"data": [{"id": str(user_1_comment.id), "type": "post_comment"}]}, + }, + "type": "post", + }, + { + "attributes": PostAttributesBaseSchema.from_orm(user_1_post).dict(include={"title"}), + "id": str(user_1_post.id), + "relationships": { + "comments": {"data": [{"id": str(user_2_comment.id), "type": "post_comment"}]}, + }, + "type": "post", + }, + { + "attributes": {}, + "id": "1", + "type": "post_comment", + }, + { + "attributes": {}, + "id": "2", + "type": "post_comment", + }, + ], + key=lambda x: (x["type"], x["id"]), + ), + } + class TestCreatePostAndComments: async def test_get_posts_with_users( From 11298015c6fc2f53db7629f752388cfae2cb2846 Mon Sep 17 00:00:00 2001 From: German Bernadskiy Date: Thu, 7 Mar 2024 16:18:46 +1000 Subject: [PATCH 3/7] fix --- fastapi_jsonapi/querystring.py | 3 +++ tests/fixtures/app.py | 2 +- tests/test_api/test_api_sqla_with_includes.py | 25 ++++++++++++------- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/fastapi_jsonapi/querystring.py b/fastapi_jsonapi/querystring.py index 322591ec..1d1c752c 100644 --- a/fastapi_jsonapi/querystring.py +++ b/fastapi_jsonapi/querystring.py @@ -242,6 +242,9 @@ def fields(self) -> Dict[str, List[str]]: self._get_schema(resource_type) for field_name in field_names: + if field_name == "": + continue + if field_name not in schema.__fields__: msg = "{schema} has no attribute {field}".format( schema=schema.__name__, diff --git a/tests/fixtures/app.py b/tests/fixtures/app.py index 1bf87887..b0a68075 100644 --- a/tests/fixtures/app.py +++ b/tests/fixtures/app.py @@ -112,7 +112,7 @@ def add_routers(app_plain: FastAPI): class_detail=DetailViewBaseGeneric, class_list=ListViewBaseGeneric, schema=PostCommentSchema, - resource_type="comment", + resource_type="post_comment", model=PostComment, ) diff --git a/tests/test_api/test_api_sqla_with_includes.py b/tests/test_api/test_api_sqla_with_includes.py index 688e171f..f01e7e3d 100644 --- a/tests/test_api/test_api_sqla_with_includes.py +++ b/tests/test_api/test_api_sqla_with_includes.py @@ -232,7 +232,7 @@ async def test_select_custom_fields_with_includes( ("fields[user]", "name"), ("fields[post]", "title"), # empty str means ignore all fields - ("fields[comment]", ""), + ("fields[post_comment]", ""), ("include", "posts,posts.comments"), ("sort", "id"), ], @@ -284,7 +284,14 @@ async def test_select_custom_fields_with_includes( "attributes": PostAttributesBaseSchema.from_orm(user_2_post).dict(include={"title"}), "id": str(user_2_post.id), "relationships": { - "comments": {"data": [{"id": str(user_1_comment.id), "type": "post_comment"}]}, + "comments": { + "data": [ + { + "id": str(user_1_comment.id), + "type": "post_comment", + }, + ], + }, }, "type": "post", }, @@ -298,12 +305,12 @@ async def test_select_custom_fields_with_includes( }, { "attributes": {}, - "id": "1", + "id": str(user_1_comment.id), "type": "post_comment", }, { "attributes": {}, - "id": "2", + "id": str(user_2_comment.id), "type": "post_comment", }, ], @@ -426,7 +433,7 @@ async def test_create_comments_for_post( user_2: User, user_1_post: Post, ): - url = app.url_path_for("get_comment_list") + url = app.url_path_for("get_post_comment_list") url = f"{url}?include=author,post,post.user" comment_attributes = PostCommentAttributesBaseSchema( text=fake.sentence(), @@ -457,7 +464,7 @@ async def test_create_comments_for_post( comment_id = comment_data.pop("id") assert comment_id assert comment_data == { - "type": "comment", + "type": "post_comment", "attributes": comment_attributes, "relationships": { "post": { @@ -515,7 +522,7 @@ async def test_create_comment_error_no_relationship( :param user_1_post: :return: """ - url = app.url_path_for("get_comment_list") + url = app.url_path_for("get_post_comment_list") comment_attributes = PostCommentAttributesBaseSchema( text=fake.sentence(), ).dict() @@ -556,7 +563,7 @@ async def test_create_comment_error_no_relationships_content( app: FastAPI, client: AsyncClient, ): - url = app.url_path_for("get_comment_list") + url = app.url_path_for("get_post_comment_list") comment_attributes = PostCommentAttributesBaseSchema( text=fake.sentence(), ).dict() @@ -602,7 +609,7 @@ async def test_create_comment_error_no_relationships_field( app: FastAPI, client: AsyncClient, ): - url = app.url_path_for("get_comment_list") + url = app.url_path_for("get_post_comment_list") comment_attributes = PostCommentAttributesBaseSchema( text=fake.sentence(), ).dict() From b5879bc49d923e568ef6d0da9687f0f37bd94b92 Mon Sep 17 00:00:00 2001 From: German Bernadskiy Date: Thu, 7 Mar 2024 18:21:35 +1000 Subject: [PATCH 4/7] added fields tests for other crud operations --- fastapi_jsonapi/views/detail_view.py | 11 +- fastapi_jsonapi/views/list_view.py | 12 +- fastapi_jsonapi/views/utils.py | 4 +- tests/test_api/test_api_sqla_with_includes.py | 119 +++++++++++++++++- 4 files changed, 133 insertions(+), 13 deletions(-) diff --git a/fastapi_jsonapi/views/detail_view.py b/fastapi_jsonapi/views/detail_view.py index d80353d5..712f2c17 100644 --- a/fastapi_jsonapi/views/detail_view.py +++ b/fastapi_jsonapi/views/detail_view.py @@ -12,6 +12,7 @@ BaseJSONAPIItemInSchema, JSONAPIResultDetailSchema, ) +from fastapi_jsonapi.views.utils import handle_jsonapi_fields from fastapi_jsonapi.views.view_base import ViewBase if TYPE_CHECKING: @@ -34,22 +35,24 @@ async def handle_get_resource_detail( self, object_id: Union[int, str], **extra_view_deps, - ): + ) -> Union[JSONAPIResultDetailSchema, Dict]: dl: "BaseDataLayer" = await self.get_data_layer(extra_view_deps) view_kwargs = {dl.url_id_field: object_id} db_object = await dl.get_object(view_kwargs=view_kwargs, qs=self.query_params) - return self._build_detail_response(db_object) + response = self._build_detail_response(db_object) + return handle_jsonapi_fields(response, self.query_params, self.jsonapi) async def handle_update_resource( self, obj_id: str, data_update: BaseJSONAPIItemInSchema, **extra_view_deps, - ) -> JSONAPIResultDetailSchema: + ) -> Union[JSONAPIResultDetailSchema, Dict]: dl: "BaseDataLayer" = await self.get_data_layer(extra_view_deps) - return await self.process_update_object(dl=dl, obj_id=obj_id, data_update=data_update) + response = await self.process_update_object(dl=dl, obj_id=obj_id, data_update=data_update) + return handle_jsonapi_fields(response, self.query_params, self.jsonapi) async def process_update_object( self, diff --git a/fastapi_jsonapi/views/list_view.py b/fastapi_jsonapi/views/list_view.py index 82a4e1ac..e6fc59a1 100644 --- a/fastapi_jsonapi/views/list_view.py +++ b/fastapi_jsonapi/views/list_view.py @@ -6,7 +6,7 @@ JSONAPIResultDetailSchema, JSONAPIResultListSchema, ) -from fastapi_jsonapi.views.utils import handle_fields +from fastapi_jsonapi.views.utils import handle_jsonapi_fields from fastapi_jsonapi.views.view_base import ViewBase if TYPE_CHECKING: @@ -42,15 +42,16 @@ async def handle_get_resource_list(self, **extra_view_deps) -> Union[JSONAPIResu total_pages = self._calculate_total_pages(count) response = self._build_list_response(items_from_db, count, total_pages) - return handle_fields(response, query_params, self.jsonapi) + return handle_jsonapi_fields(response, query_params, self.jsonapi) async def handle_post_resource_list( self, data_create: BaseJSONAPIItemInSchema, **extra_view_deps, - ) -> JSONAPIResultDetailSchema: + ) -> Union[JSONAPIResultDetailSchema, Dict]: dl: "BaseDataLayer" = await self.get_data_layer(extra_view_deps) - return await self.process_create_object(dl=dl, data_create=data_create) + response = await self.process_create_object(dl=dl, data_create=data_create) + return handle_jsonapi_fields(response, self.query_params, self.jsonapi) async def process_create_object(self, dl: "BaseDataLayer", data_create: BaseJSONAPIItemInSchema): created_object = await dl.create_object(data_create=data_create, view_kwargs={}) @@ -70,4 +71,5 @@ async def handle_delete_resource_list(self, **extra_view_deps) -> JSONAPIResultL await dl.delete_objects(items_from_db, {}) - return self._build_list_response(items_from_db, count, total_pages) + response = self._build_list_response(items_from_db, count, total_pages) + return handle_jsonapi_fields(response, self.query_params, self.jsonapi) diff --git a/fastapi_jsonapi/views/utils.py b/fastapi_jsonapi/views/utils.py index 8be31129..e521d773 100644 --- a/fastapi_jsonapi/views/utils.py +++ b/fastapi_jsonapi/views/utils.py @@ -140,7 +140,7 @@ def _calculate_exclude_fields( return exclude_params -def handle_fields( +def handle_jsonapi_fields( response: JSONAPIResponse, query_params: QueryStringManager, jsonapi: RoutersJSONAPI, @@ -151,6 +151,6 @@ def handle_fields( exclude_params = _calculate_exclude_fields(response, query_params, jsonapi) if exclude_params: - return response.dict(exclude=exclude_params) + return response.dict(exclude=exclude_params, by_alias=True) return response diff --git a/tests/test_api/test_api_sqla_with_includes.py b/tests/test_api/test_api_sqla_with_includes.py index f01e7e3d..7cea9569 100644 --- a/tests/test_api/test_api_sqla_with_includes.py +++ b/tests/test_api/test_api_sqla_with_includes.py @@ -207,7 +207,7 @@ async def test_select_custom_fields( }, ], "jsonapi": {"version": "1.0"}, - "meta": {"count": 2, "total_pages": 1}, + "meta": {"count": 2, "totalPages": 1}, } async def test_select_custom_fields_with_includes( @@ -277,7 +277,7 @@ async def test_select_custom_fields_with_includes( }, ], "jsonapi": {"version": "1.0"}, - "meta": {"count": 2, "total_pages": 1}, + "meta": {"count": 2, "totalPages": 1}, "included": sorted( [ { @@ -780,6 +780,32 @@ async def test_many_to_many_load_inner_includes_to_parents( assert ("child", ViewBase.get_db_item_id(child_4)) not in included_data +class TestGetUserDetail: + def get_url(self, app: FastAPI, user_id: int) -> str: + return app.url_path_for("get_user_detail", obj_id=user_id) + + async def test_select_custom_fields( + self, + app: FastAPI, + client: AsyncClient, + user_1: User, + ): + url = self.get_url(app, user_1.id) + params = QueryParams([("fields[user]", "name,age")]) + response = await client.get(url, params=params) + + assert response.status_code == status.HTTP_200_OK + assert response.json() == { + "data": { + "attributes": UserAttributesBaseSchema.from_orm(user_1).dict(include={"name", "age"}), + "id": str(user_1.id), + "type": "user", + }, + "jsonapi": {"version": "1.0"}, + "meta": None, + } + + class TestUserWithPostsWithInnerIncludes: @mark.parametrize( "include, expected_relationships_inner_relations, expect_user_include", @@ -1503,6 +1529,34 @@ class ContainsTimestampAttrsSchema(BaseModel): "data": [], } + async def test_select_custom_fields(self, app: FastAPI, client: AsyncClient): + user_attrs_schema = UserAttributesBaseSchema( + name=fake.name(), + age=fake.pyint(), + email=fake.email(), + ) + create_user_body = { + "data": { + "attributes": user_attrs_schema.dict(), + }, + } + params = QueryParams([("fields[user]", "name")]) + url = app.url_path_for("get_user_list") + res = await client.post(url, json=create_user_body, params=params) + assert res.status_code == status.HTTP_201_CREATED, res.text + response_data: dict = res.json() + + assert "data" in response_data + assert response_data["data"].pop("id") + assert response_data == { + "data": { + "attributes": user_attrs_schema.dict(include={"name"}), + "type": "user", + }, + "jsonapi": {"version": "1.0"}, + "meta": None, + } + class TestPatchObjects: async def test_patch_object( @@ -1615,6 +1669,39 @@ async def test_update_schema_has_extra_fields(self, user_1: User, caplog): ): assert expected in log_message + async def test_select_custom_fields( + self, + app: FastAPI, + client: AsyncClient, + user_1: User, + ): + new_attrs = UserAttributesBaseSchema( + name=fake.name(), + age=fake.pyint(), + email=fake.email(), + ) + + patch_user_body = { + "data": { + "id": user_1.id, + "attributes": new_attrs.dict(), + }, + } + params = QueryParams([("fields[user]", "name")]) + url = app.url_path_for("get_user_detail", obj_id=user_1.id) + res = await client.patch(url, params=params, json=patch_user_body) + + assert res.status_code == status.HTTP_200_OK, res.text + assert res.json() == { + "data": { + "attributes": new_attrs.dict(include={"name"}), + "id": str(user_1.id), + "type": "user", + }, + "jsonapi": {"version": "1.0"}, + "meta": None, + } + class TestPatchObjectRelationshipsToOne: async def test_ok_when_foreign_key_of_related_object_is_nullable( @@ -2085,6 +2172,34 @@ async def test_delete_objects_many( "meta": {"count": 1, "totalPages": 1}, } + async def test_select_custom_fields( + self, + app: FastAPI, + client: AsyncClient, + user_1: User, + user_2: User, + ): + params = QueryParams([("fields[user]", "name")]) + url = app.url_path_for("get_user_list") + res = await client.delete(url, params=params) + assert res.status_code == status.HTTP_200_OK, res.text + assert res.json() == { + "data": [ + { + "attributes": UserAttributesBaseSchema.from_orm(user_1).dict(include={"name"}), + "id": str(user_1.id), + "type": "user", + }, + { + "attributes": UserAttributesBaseSchema.from_orm(user_2).dict(include={"name"}), + "id": str(user_2.id), + "type": "user", + }, + ], + "jsonapi": {"version": "1.0"}, + "meta": {"count": 2, "totalPages": 1}, + } + class TestOpenApi: def test_openapi_method_ok(self, app: FastAPI): From e09170266e7c3be1c5fda95140b2260d6b325502 Mon Sep 17 00:00:00 2001 From: German Bernadskiy Date: Thu, 7 Mar 2024 19:15:21 +1000 Subject: [PATCH 5/7] added coverage tests --- tests/test_fastapi_jsonapi/__init__.py | 0 .../test_fastapi_jsonapi/test_querystring.py | 92 +++++++++++++++++++ 2 files changed, 92 insertions(+) create mode 100644 tests/test_fastapi_jsonapi/__init__.py create mode 100644 tests/test_fastapi_jsonapi/test_querystring.py diff --git a/tests/test_fastapi_jsonapi/__init__.py b/tests/test_fastapi_jsonapi/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_fastapi_jsonapi/test_querystring.py b/tests/test_fastapi_jsonapi/test_querystring.py new file mode 100644 index 00000000..dec04365 --- /dev/null +++ b/tests/test_fastapi_jsonapi/test_querystring.py @@ -0,0 +1,92 @@ +import json +from unittest.mock import MagicMock + +import pytest +from fastapi import status +from starlette.datastructures import QueryParams + +from fastapi_jsonapi.exceptions import InvalidFilters +from fastapi_jsonapi.exceptions.json_api import BadRequest +from fastapi_jsonapi.querystring import QueryStringManager + + +def test__extract_item_key(): + manager = QueryStringManager(MagicMock()) + + key = "fields[user]" + assert manager._extract_item_key(key) == "user" + + with pytest.raises(BadRequest) as exc_info: # noqa: PT012 + key = "fields[user" + manager._extract_item_key(key) + + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + assert exc_info.value.detail == { + "errors": [ + { + "status_code": status.HTTP_400_BAD_REQUEST, + "source": {"parameter": "fields[user"}, + "title": "Bad Request", + "detail": "Parse error", + }, + ], + } + + +def test_querystring(): + request = MagicMock() + request.query_params = QueryParams([("fields[user]", "name")]) + manager = QueryStringManager(request) + assert manager.querystring == {"fields[user]": "name"} + + +def test_filters__errors(): + request = MagicMock() + request.query_params = QueryParams([("filter", "not_json")]) + manager = QueryStringManager(request) + + with pytest.raises(InvalidFilters) as exc_info: + manager.filters + + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + assert exc_info.value.detail == { + "errors": [ + { + "status_code": status.HTTP_400_BAD_REQUEST, + "source": {"parameter": "filters"}, + "title": "Invalid filters querystring parameter.", + "detail": "Parse error", + }, + ], + } + + request.query_params = QueryParams( + [ + ( + "filter", + json.dumps( + { + "name": "", + "op": "", + "val": "", + }, + ), + ), + ], + ) + manager = QueryStringManager(request) + + with pytest.raises(InvalidFilters) as exc_info: + manager.filters + + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + assert exc_info.value.detail == { + "errors": [ + { + "status_code": status.HTTP_400_BAD_REQUEST, + "source": {"parameter": "filters"}, + "title": "Invalid filters querystring parameter.", + "detail": "Incorrect filters format, expected list of conditions but got dict", + }, + ], + } From 43d7dc06c9c85e678a9cc4959769b1333cf1d05b Mon Sep 17 00:00:00 2001 From: German Bernadskiy Date: Mon, 11 Mar 2024 18:26:01 +1000 Subject: [PATCH 6/7] added coverage test --- fastapi_jsonapi/querystring.py | 8 ++---- tests/test_api/test_api_sqla_with_includes.py | 26 +++++++++++++++++++ 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/fastapi_jsonapi/querystring.py b/fastapi_jsonapi/querystring.py index 1d1c752c..c68bbba4 100644 --- a/fastapi_jsonapi/querystring.py +++ b/fastapi_jsonapi/querystring.py @@ -212,8 +212,6 @@ def pagination(self) -> PaginationQueryStringManager: return pagination - # TODO: finally use this! upgrade Sqlachemy Data Layer - # and add to all views (get list/detail, create, patch) @property def fields(self) -> Dict[str, List[str]]: """ @@ -238,8 +236,7 @@ def fields(self) -> Dict[str, List[str]]: msg = f"Application has no resource with type {resource_type!r}" raise InvalidType(msg) - schema: Type[BaseModel] = RoutersJSONAPI.all_jsonapi_routers[resource_type]._schema - self._get_schema(resource_type) + schema: Type[BaseModel] = self._get_schema(resource_type) for field_name in field_names: if field_name == "": @@ -255,8 +252,7 @@ def fields(self) -> Dict[str, List[str]]: return {resource_type: set(field_names) for resource_type, field_names in fields.items()} def _get_schema(self, resource_type: str) -> Type[BaseModel]: - target_router = RoutersJSONAPI.all_jsonapi_routers[resource_type] - return target_router.detail_response_schema + return RoutersJSONAPI.all_jsonapi_routers[resource_type]._schema def get_sorts(self, schema: Type["TypeSchema"]) -> List[Dict[str, str]]: """ diff --git a/tests/test_api/test_api_sqla_with_includes.py b/tests/test_api/test_api_sqla_with_includes.py index 7cea9569..3beb9360 100644 --- a/tests/test_api/test_api_sqla_with_includes.py +++ b/tests/test_api/test_api_sqla_with_includes.py @@ -318,6 +318,32 @@ async def test_select_custom_fields_with_includes( ), } + async def test_select_custom_fields_for_includes_without_requesting_includes( + self, + app: FastAPI, + client: AsyncClient, + user_1: User, + ): + url = app.url_path_for("get_user_list") + + params = QueryParams([("fields[post]", "title")]) + response = await client.get(url, params=str(params)) + + assert response.status_code == status.HTTP_200_OK, response.text + response_data = response.json() + + assert response_data == { + "data": [ + { + "attributes": UserAttributesBaseSchema.from_orm(user_1), + "id": str(user_1.id), + "type": "user", + }, + ], + "jsonapi": {"version": "1.0"}, + "meta": {"count": 1, "totalPages": 1}, + } + class TestCreatePostAndComments: async def test_get_posts_with_users( From ebc59b5cdb00cb7e800e822574b237bff7accfb7 Mon Sep 17 00:00:00 2001 From: German Bernadskiy Date: Wed, 13 Mar 2024 13:52:27 +1000 Subject: [PATCH 7/7] issue update --- fastapi_jsonapi/schema_builder.py | 1 - tests/test_api/test_api_sqla_with_includes.py | 51 +++++++++++++------ 2 files changed, 36 insertions(+), 16 deletions(-) diff --git a/fastapi_jsonapi/schema_builder.py b/fastapi_jsonapi/schema_builder.py index 43b5caed..3db08eeb 100644 --- a/fastapi_jsonapi/schema_builder.py +++ b/fastapi_jsonapi/schema_builder.py @@ -484,7 +484,6 @@ def create_jsonapi_object_schemas( base_name: str = "", compute_included_schemas: bool = False, use_schema_cache: bool = True, - exclude_attributes: Optional[List[str]] = None, ) -> JSONAPIObjectSchemas: if use_schema_cache and schema in self.object_schemas_cache and includes is not_passed: return self.object_schemas_cache[schema] diff --git a/tests/test_api/test_api_sqla_with_includes.py b/tests/test_api/test_api_sqla_with_includes.py index 3beb9360..285db092 100644 --- a/tests/test_api/test_api_sqla_with_includes.py +++ b/tests/test_api/test_api_sqla_with_includes.py @@ -227,10 +227,13 @@ async def test_select_custom_fields_with_includes( user_1_comment = await build_post_comment(async_session, user_1, user_2_post) user_2_comment = await build_post_comment(async_session, user_2, user_1_post) + queried_user_fields = "name" + queried_post_fields = "title" + params = QueryParams( [ - ("fields[user]", "name"), - ("fields[post]", "title"), + ("fields[user]", queried_user_fields), + ("fields[post]", queried_post_fields), # empty str means ignore all fields ("fields[post_comment]", ""), ("include", "posts,posts.comments"), @@ -246,7 +249,9 @@ async def test_select_custom_fields_with_includes( assert response_data == { "data": [ { - "attributes": UserAttributesBaseSchema.from_orm(user_1).dict(include={"name"}), + "attributes": UserAttributesBaseSchema.from_orm(user_1).dict( + include=set(queried_user_fields.split(",")), + ), "relationships": { "posts": { "data": [ @@ -261,7 +266,9 @@ async def test_select_custom_fields_with_includes( "type": "user", }, { - "attributes": UserAttributesBaseSchema.from_orm(user_2).dict(include={"name"}), + "attributes": UserAttributesBaseSchema.from_orm(user_2).dict( + include=set(queried_user_fields.split(",")), + ), "relationships": { "posts": { "data": [ @@ -281,7 +288,9 @@ async def test_select_custom_fields_with_includes( "included": sorted( [ { - "attributes": PostAttributesBaseSchema.from_orm(user_2_post).dict(include={"title"}), + "attributes": PostAttributesBaseSchema.from_orm(user_2_post).dict( + include=set(queried_post_fields.split(",")), + ), "id": str(user_2_post.id), "relationships": { "comments": { @@ -296,7 +305,9 @@ async def test_select_custom_fields_with_includes( "type": "post", }, { - "attributes": PostAttributesBaseSchema.from_orm(user_1_post).dict(include={"title"}), + "attributes": PostAttributesBaseSchema.from_orm(user_1_post).dict( + include=set(queried_post_fields.split(",")), + ), "id": str(user_1_post.id), "relationships": { "comments": {"data": [{"id": str(user_2_comment.id), "type": "post_comment"}]}, @@ -817,13 +828,16 @@ async def test_select_custom_fields( user_1: User, ): url = self.get_url(app, user_1.id) - params = QueryParams([("fields[user]", "name,age")]) + queried_user_fields = "name,age" + params = QueryParams([("fields[user]", queried_user_fields)]) response = await client.get(url, params=params) assert response.status_code == status.HTTP_200_OK assert response.json() == { "data": { - "attributes": UserAttributesBaseSchema.from_orm(user_1).dict(include={"name", "age"}), + "attributes": UserAttributesBaseSchema.from_orm(user_1).dict( + include=set(queried_user_fields.split(",")), + ), "id": str(user_1.id), "type": "user", }, @@ -1566,7 +1580,8 @@ async def test_select_custom_fields(self, app: FastAPI, client: AsyncClient): "attributes": user_attrs_schema.dict(), }, } - params = QueryParams([("fields[user]", "name")]) + queried_user_fields = "name" + params = QueryParams([("fields[user]", queried_user_fields)]) url = app.url_path_for("get_user_list") res = await client.post(url, json=create_user_body, params=params) assert res.status_code == status.HTTP_201_CREATED, res.text @@ -1576,7 +1591,7 @@ async def test_select_custom_fields(self, app: FastAPI, client: AsyncClient): assert response_data["data"].pop("id") assert response_data == { "data": { - "attributes": user_attrs_schema.dict(include={"name"}), + "attributes": user_attrs_schema.dict(include=set(queried_user_fields.split(","))), "type": "user", }, "jsonapi": {"version": "1.0"}, @@ -1713,14 +1728,15 @@ async def test_select_custom_fields( "attributes": new_attrs.dict(), }, } - params = QueryParams([("fields[user]", "name")]) + queried_user_fields = "name" + params = QueryParams([("fields[user]", queried_user_fields)]) url = app.url_path_for("get_user_detail", obj_id=user_1.id) res = await client.patch(url, params=params, json=patch_user_body) assert res.status_code == status.HTTP_200_OK, res.text assert res.json() == { "data": { - "attributes": new_attrs.dict(include={"name"}), + "attributes": new_attrs.dict(include=set(queried_user_fields.split(","))), "id": str(user_1.id), "type": "user", }, @@ -2205,19 +2221,24 @@ async def test_select_custom_fields( user_1: User, user_2: User, ): - params = QueryParams([("fields[user]", "name")]) + queried_user_fields = "name" + params = QueryParams([("fields[user]", queried_user_fields)]) url = app.url_path_for("get_user_list") res = await client.delete(url, params=params) assert res.status_code == status.HTTP_200_OK, res.text assert res.json() == { "data": [ { - "attributes": UserAttributesBaseSchema.from_orm(user_1).dict(include={"name"}), + "attributes": UserAttributesBaseSchema.from_orm(user_1).dict( + include=set(queried_user_fields.split(",")), + ), "id": str(user_1.id), "type": "user", }, { - "attributes": UserAttributesBaseSchema.from_orm(user_2).dict(include={"name"}), + "attributes": UserAttributesBaseSchema.from_orm(user_2).dict( + include=set(queried_user_fields.split(",")), + ), "id": str(user_2.id), "type": "user", },