diff --git a/fastapi_jsonapi/querystring.py b/fastapi_jsonapi/querystring.py index 8503e23b..c68bbba4 100644 --- a/fastapi_jsonapi/querystring.py +++ b/fastapi_jsonapi/querystring.py @@ -1,4 +1,5 @@ """Helper to deal with querystring parameters according to jsonapi specification.""" +from collections import defaultdict from functools import cached_property from typing import ( TYPE_CHECKING, @@ -7,7 +8,6 @@ List, Optional, Type, - Union, ) from urllib.parse import unquote @@ -22,17 +22,18 @@ ) from starlette.datastructures import QueryParams +from fastapi_jsonapi.api import RoutersJSONAPI from fastapi_jsonapi.exceptions import ( BadRequest, InvalidField, InvalidFilters, InvalidInclude, InvalidSort, + InvalidType, ) from fastapi_jsonapi.schema import ( get_model_field, get_relationships, - get_schema_from_type, ) from fastapi_jsonapi.splitter import SPLIT_REL @@ -89,7 +90,16 @@ def __init__(self, request: Request) -> None: self.MAX_INCLUDE_DEPTH: int = self.config.get("MAX_INCLUDE_DEPTH", 3) self.headers: HeadersQueryStringManager = HeadersQueryStringManager(**dict(self.request.headers)) - def _get_key_values(self, name: str) -> Dict[str, Union[List[str], str]]: + def _extract_item_key(self, key: str) -> str: + try: + key_start = key.index("[") + 1 + key_end = key.index("]") + return key[key_start:key_end] + except Exception: + msg = "Parse error" + raise BadRequest(msg, parameter=key) + + def _get_unique_key_values(self, name: str) -> Dict[str, str]: """ Return a dict containing key / values items for a given key, used for items like filters, page, etc. @@ -97,25 +107,28 @@ def _get_key_values(self, name: str) -> Dict[str, Union[List[str], str]]: :return: a dict of key / values items :raises BadRequest: if an error occurred while parsing the querystring. """ - results: Dict[str, Union[List[str], str]] = {} + results = {} for raw_key, value in self.qs.multi_items(): key = unquote(raw_key) - try: - if not key.startswith(name): - continue + if not key.startswith(name): + continue - key_start = key.index("[") + 1 - key_end = key.index("]") - item_key = key[key_start:key_end] + item_key = self._extract_item_key(key) + results[item_key] = value - if "," in value: - results.update({item_key: value.split(",")}) - else: - results.update({item_key: value}) - except Exception: - msg = "Parse error" - raise BadRequest(msg, parameter=key) + return results + + def _get_multiple_key_values(self, name: str) -> Dict[str, List]: + results = defaultdict(list) + + for raw_key, value in self.qs.multi_items(): + key = unquote(raw_key) + if not key.startswith(name): + continue + + item_key = self._extract_item_key(key) + results[item_key].extend(value.split(",")) return results @@ -134,7 +147,7 @@ def querystring(self) -> Dict[str, str]: return { key: value for (key, value) in self.qs.multi_items() - if key.startswith(self.managed_keys) or self._get_key_values("filter[") + if key.startswith(self.managed_keys) or self._get_unique_key_values("filter[") } @property @@ -159,8 +172,8 @@ def filters(self) -> List[dict]: raise InvalidFilters(msg) results.extend(loaded_filters) - if self._get_key_values("filter["): - results.extend(self._simple_filters(self._get_key_values("filter["))) + if filter_key_values := self._get_unique_key_values("filter["): + results.extend(self._simple_filters(filter_key_values)) return results @cached_property @@ -186,7 +199,7 @@ def pagination(self) -> PaginationQueryStringManager: :raises BadRequest: if the client is not allowed to disable pagination. """ # check values type - pagination_data: Dict[str, Union[List[str], str]] = self._get_key_values("page") + pagination_data: Dict[str, str] = self._get_unique_key_values("page") pagination = PaginationQueryStringManager(**pagination_data) if pagination_data.get("size") is None: pagination.size = None @@ -199,8 +212,6 @@ def pagination(self) -> PaginationQueryStringManager: return pagination - # TODO: finally use this! upgrade Sqlachemy Data Layer - # and add to all views (get list/detail, create, patch) @property def fields(self) -> Dict[str, List[str]]: """ @@ -216,26 +227,32 @@ def fields(self) -> Dict[str, List[str]]: :raises InvalidField: if result field not in schema. """ - if self.request.method != "GET": - msg = "attribute 'fields' allowed only for GET-method" - raise InvalidField(msg) - fields = self._get_key_values("fields") - for key, value in fields.items(): - if not isinstance(value, list): - value = [value] # noqa: PLW2901 - fields[key] = value + fields = self._get_multiple_key_values("fields") + for resource_type, field_names in fields.items(): # TODO: we have registry for models (BaseModel) # TODO: create `type to schemas` registry - schema: Type[BaseModel] = get_schema_from_type(key, self.app) - for field in value: - if field not in schema.__fields__: + + if resource_type not in RoutersJSONAPI.all_jsonapi_routers: + msg = f"Application has no resource with type {resource_type!r}" + raise InvalidType(msg) + + schema: Type[BaseModel] = self._get_schema(resource_type) + + for field_name in field_names: + if field_name == "": + continue + + if field_name not in schema.__fields__: msg = "{schema} has no attribute {field}".format( schema=schema.__name__, - field=field, + field=field_name, ) raise InvalidField(msg) - return fields + return {resource_type: set(field_names) for resource_type, field_names in fields.items()} + + def _get_schema(self, resource_type: str) -> Type[BaseModel]: + return RoutersJSONAPI.all_jsonapi_routers[resource_type]._schema def get_sorts(self, schema: Type["TypeSchema"]) -> List[Dict[str, str]]: """ diff --git a/fastapi_jsonapi/views/detail_view.py b/fastapi_jsonapi/views/detail_view.py index d80353d5..712f2c17 100644 --- a/fastapi_jsonapi/views/detail_view.py +++ b/fastapi_jsonapi/views/detail_view.py @@ -12,6 +12,7 @@ BaseJSONAPIItemInSchema, JSONAPIResultDetailSchema, ) +from fastapi_jsonapi.views.utils import handle_jsonapi_fields from fastapi_jsonapi.views.view_base import ViewBase if TYPE_CHECKING: @@ -34,22 +35,24 @@ async def handle_get_resource_detail( self, object_id: Union[int, str], **extra_view_deps, - ): + ) -> Union[JSONAPIResultDetailSchema, Dict]: dl: "BaseDataLayer" = await self.get_data_layer(extra_view_deps) view_kwargs = {dl.url_id_field: object_id} db_object = await dl.get_object(view_kwargs=view_kwargs, qs=self.query_params) - return self._build_detail_response(db_object) + response = self._build_detail_response(db_object) + return handle_jsonapi_fields(response, self.query_params, self.jsonapi) async def handle_update_resource( self, obj_id: str, data_update: BaseJSONAPIItemInSchema, **extra_view_deps, - ) -> JSONAPIResultDetailSchema: + ) -> Union[JSONAPIResultDetailSchema, Dict]: dl: "BaseDataLayer" = await self.get_data_layer(extra_view_deps) - return await self.process_update_object(dl=dl, obj_id=obj_id, data_update=data_update) + response = await self.process_update_object(dl=dl, obj_id=obj_id, data_update=data_update) + return handle_jsonapi_fields(response, self.query_params, self.jsonapi) async def process_update_object( self, diff --git a/fastapi_jsonapi/views/list_view.py b/fastapi_jsonapi/views/list_view.py index e7be5421..e6fc59a1 100644 --- a/fastapi_jsonapi/views/list_view.py +++ b/fastapi_jsonapi/views/list_view.py @@ -1,11 +1,12 @@ import logging -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any, Dict, Union from fastapi_jsonapi.schema import ( BaseJSONAPIItemInSchema, JSONAPIResultDetailSchema, JSONAPIResultListSchema, ) +from fastapi_jsonapi.views.utils import handle_jsonapi_fields from fastapi_jsonapi.views.view_base import ViewBase if TYPE_CHECKING: @@ -34,21 +35,23 @@ async def get_data_layer( ) -> "BaseDataLayer": return await self.get_data_layer_for_list(extra_view_deps) - async def handle_get_resource_list(self, **extra_view_deps) -> JSONAPIResultListSchema: + async def handle_get_resource_list(self, **extra_view_deps) -> Union[JSONAPIResultListSchema, Dict]: dl: "BaseDataLayer" = await self.get_data_layer(extra_view_deps) query_params = self.query_params count, items_from_db = await dl.get_collection(qs=query_params) total_pages = self._calculate_total_pages(count) - return self._build_list_response(items_from_db, count, total_pages) + response = self._build_list_response(items_from_db, count, total_pages) + return handle_jsonapi_fields(response, query_params, self.jsonapi) async def handle_post_resource_list( self, data_create: BaseJSONAPIItemInSchema, **extra_view_deps, - ) -> JSONAPIResultDetailSchema: + ) -> Union[JSONAPIResultDetailSchema, Dict]: dl: "BaseDataLayer" = await self.get_data_layer(extra_view_deps) - return await self.process_create_object(dl=dl, data_create=data_create) + response = await self.process_create_object(dl=dl, data_create=data_create) + return handle_jsonapi_fields(response, self.query_params, self.jsonapi) async def process_create_object(self, dl: "BaseDataLayer", data_create: BaseJSONAPIItemInSchema): created_object = await dl.create_object(data_create=data_create, view_kwargs={}) @@ -68,4 +71,5 @@ async def handle_delete_resource_list(self, **extra_view_deps) -> JSONAPIResultL await dl.delete_objects(items_from_db, {}) - return self._build_list_response(items_from_db, count, total_pages) + response = self._build_list_response(items_from_db, count, total_pages) + return handle_jsonapi_fields(response, self.query_params, self.jsonapi) diff --git a/fastapi_jsonapi/views/utils.py b/fastapi_jsonapi/views/utils.py index 5f80af1a..e521d773 100644 --- a/fastapi_jsonapi/views/utils.py +++ b/fastapi_jsonapi/views/utils.py @@ -1,8 +1,39 @@ +from __future__ import annotations + +from collections import defaultdict from enum import Enum from functools import cache -from typing import Callable, Coroutine, Optional, Set, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Coroutine, + Dict, + Iterable, + List, + Optional, + Set, + Type, + Union, +) from pydantic import BaseModel +from pydantic.fields import ModelField + +from fastapi_jsonapi.data_typing import TypeSchema +from fastapi_jsonapi.schema import JSONAPIObjectSchema +from fastapi_jsonapi.schema_builder import ( + JSONAPIResultDetailSchema, + JSONAPIResultListSchema, +) + +if TYPE_CHECKING: + from fastapi_jsonapi.api import RoutersJSONAPI + from fastapi_jsonapi.querystring import QueryStringManager + + +JSONAPIResponse = Union[JSONAPIResultDetailSchema, JSONAPIResultListSchema] +IGNORE_ALL_FIELDS_LITERAL = "" class HTTPMethod(Enum): @@ -27,3 +58,99 @@ class Config: @property def handler(self) -> Optional[Union[Callable, Coroutine]]: return self.prepare_data_layer_kwargs + + +def _get_includes_indexes_by_type(included: List[JSONAPIObjectSchema]) -> Dict[str, List[int]]: + result = defaultdict(list) + + for idx, item in enumerate(included): + result[item.type].append(idx) + + return result + + +# TODO: move to schema builder? +def _is_relationship_field(field: ModelField) -> bool: + return "relationship" in field.field_info.extra + + +def _get_schema_field_names(schema: Type[TypeSchema]) -> Set[str]: + """ + Returns all attribute names except relationships + """ + result = set() + + for field_name, field in schema.__fields__.items(): + if _is_relationship_field(field): + continue + + result.add(field_name) + + return result + + +def _get_exclude_fields( + schema: Type[TypeSchema], + include_fields: Iterable[str], +) -> Set[str]: + schema_fields = _get_schema_field_names(schema) + + if IGNORE_ALL_FIELDS_LITERAL in include_fields: + return schema_fields + + return set(_get_schema_field_names(schema)).difference(include_fields) + + +def _calculate_exclude_fields( + response: JSONAPIResponse, + query_params: QueryStringManager, + jsonapi: RoutersJSONAPI, +) -> Dict: + included = "included" in response.__fields__ and response.included or [] + is_list_response = isinstance(response, JSONAPIResultListSchema) + + exclude_params: Dict[str, Any] = {} + + includes_indexes_by_type = _get_includes_indexes_by_type(included) + + for resource_type, field_names in query_params.fields.items(): + schema = jsonapi.all_jsonapi_routers[resource_type]._schema + exclude_fields = _get_exclude_fields(schema, include_fields=field_names) + attributes_exclude = {"attributes": exclude_fields} + + if resource_type == jsonapi.type_: + if is_list_response: + exclude_params["data"] = {"__all__": attributes_exclude} + else: + exclude_params["data"] = attributes_exclude + + continue + + if not included: + continue + + target_type_indexes = includes_indexes_by_type.get(resource_type) + + if target_type_indexes: + if "included" not in exclude_params: + exclude_params["included"] = {} + + exclude_params["included"].update((idx, attributes_exclude) for idx in target_type_indexes) + + return exclude_params + + +def handle_jsonapi_fields( + response: JSONAPIResponse, + query_params: QueryStringManager, + jsonapi: RoutersJSONAPI, +) -> Union[JSONAPIResponse, Dict]: + if not query_params.fields: + return response + + exclude_params = _calculate_exclude_fields(response, query_params, jsonapi) + + if exclude_params: + return response.dict(exclude=exclude_params, by_alias=True) + + return response diff --git a/fastapi_jsonapi/views/view_base.py b/fastapi_jsonapi/views/view_base.py index 9ef8dee7..5566d7ab 100644 --- a/fastapi_jsonapi/views/view_base.py +++ b/fastapi_jsonapi/views/view_base.py @@ -30,6 +30,7 @@ from fastapi_jsonapi.schema import ( JSONAPIObjectSchema, JSONAPIResultListMetaSchema, + JSONAPIResultListSchema, get_related_schema, ) from fastapi_jsonapi.schema_base import BaseModel, RelationshipInfo @@ -185,7 +186,12 @@ def _build_detail_response(self, db_item: TypeModel): return detail_jsonapi_schema(data=result_object, **extras) - def _build_list_response(self, items_from_db: List[TypeModel], count: int, total_pages: int): + def _build_list_response( + self, + items_from_db: List[TypeModel], + count: int, + total_pages: int, + ) -> JSONAPIResultListSchema: result_objects, object_schemas, extras = self._build_response(items_from_db, self.jsonapi.schema_list) # we need to build a new schema here diff --git a/tests/fixtures/app.py b/tests/fixtures/app.py index 1bf87887..b0a68075 100644 --- a/tests/fixtures/app.py +++ b/tests/fixtures/app.py @@ -112,7 +112,7 @@ def add_routers(app_plain: FastAPI): class_detail=DetailViewBaseGeneric, class_list=ListViewBaseGeneric, schema=PostCommentSchema, - resource_type="comment", + resource_type="post_comment", model=PostComment, ) diff --git a/tests/fixtures/entities.py b/tests/fixtures/entities.py index 23eb3547..f5f45970 100644 --- a/tests/fixtures/entities.py +++ b/tests/fixtures/entities.py @@ -99,20 +99,31 @@ async def user_2_bio(async_session: AsyncSession, user_2: User) -> UserBio: ) +async def build_post(async_session: AsyncSession, user: User, **fields) -> Post: + fields = {"title": fake.name(), "body": fake.sentence(), **fields} + post = Post(user=user, **fields) + async_session.add(post) + await async_session.commit() + return post + + @async_fixture() -async def user_1_posts(async_session: AsyncSession, user_1: User): - posts = [Post(title=f"post_u1_{i}", user=user_1) for i in range(1, 4)] +async def user_1_posts(async_session: AsyncSession, user_1: User) -> List[Post]: + posts = [ + Post( + title=f"post_u1_{i}", + user=user_1, + body=fake.sentence(), + ) + for i in range(1, 4) + ] async_session.add_all(posts) await async_session.commit() for post in posts: await async_session.refresh(post) - yield posts - - for post in posts: - await async_session.delete(post) - await async_session.commit() + return posts @async_fixture() @@ -130,19 +141,22 @@ async def user_1_post(async_session: AsyncSession, user_1: User): @async_fixture() -async def user_2_posts(async_session: AsyncSession, user_2: User): - posts = [Post(title=f"post_u2_{i}", user=user_2) for i in range(1, 5)] +async def user_2_posts(async_session: AsyncSession, user_2: User) -> List[Post]: + posts = [ + Post( + title=f"post_u2_{i}", + user=user_2, + body=fake.sentence(), + ) + for i in range(1, 5) + ] async_session.add_all(posts) await async_session.commit() for post in posts: await async_session.refresh(post) - yield posts - - for post in posts: - await async_session.delete(post) - await async_session.commit() + return posts @async_fixture() @@ -213,6 +227,23 @@ async def factory(name: str | None = None) -> Computer: return factory +async def build_post_comment( + async_session: AsyncSession, + user: User, + post: Post, + **fields, +) -> PostComment: + fields = {"text": fake.sentence(), **fields} + post_comment = PostComment( + author=user, + post=post, + **fields, + ) + async_session.add(post_comment) + await async_session.commit() + return post_comment + + @async_fixture() async def user_2_comment_for_one_u1_post(async_session: AsyncSession, user_2, user_1_post_for_comments): post = user_1_post_for_comments diff --git a/tests/test_api/test_api_sqla_with_includes.py b/tests/test_api/test_api_sqla_with_includes.py index d3c31d92..285db092 100644 --- a/tests/test_api/test_api_sqla_with_includes.py +++ b/tests/test_api/test_api_sqla_with_includes.py @@ -4,7 +4,7 @@ from datetime import datetime, timezone from itertools import chain, zip_longest from json import dumps, loads -from typing import Dict, List, Literal +from typing import Dict, List, Literal, Set, Tuple from uuid import UUID, uuid4 import pytest @@ -16,11 +16,17 @@ from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import InstrumentedAttribute +from starlette.datastructures import QueryParams from fastapi_jsonapi.views.view_base import ViewBase from tests.common import is_postgres_tests from tests.fixtures.app import build_alphabet_app, build_app_custom -from tests.fixtures.entities import build_workplace, create_user +from tests.fixtures.entities import ( + build_post, + build_post_comment, + build_workplace, + create_user, +) from tests.misc.utils import fake from tests.models import ( Alpha, @@ -151,6 +157,204 @@ async def test_get_users_paginated( "meta": {"count": 2, "totalPages": 2}, } + @mark.parametrize( + "fields, expected_include", + [ + param( + [ + ("fields[user]", "name,age"), + ], + {"name", "age"}, + ), + param( + [ + ("fields[user]", "name,age"), + ("fields[user]", "email"), + ], + {"name", "age", "email"}, + ), + ], + ) + async def test_select_custom_fields( + self, + app: FastAPI, + client: AsyncClient, + user_1: User, + user_2: User, + fields: List[Tuple[str, str]], + expected_include: Set[str], + ): + url = app.url_path_for("get_user_list") + user_1, user_2 = sorted((user_1, user_2), key=lambda x: x.id) + + params = QueryParams(fields) + response = await client.get(url, params=str(params)) + + assert response.status_code == status.HTTP_200_OK, response.text + response_data = response.json() + + assert response_data == { + "data": [ + { + "attributes": UserAttributesBaseSchema.from_orm(user_1).dict(include=expected_include), + "id": str(user_1.id), + "type": "user", + }, + { + "attributes": UserAttributesBaseSchema.from_orm(user_2).dict(include=expected_include), + "id": str(user_2.id), + "type": "user", + }, + ], + "jsonapi": {"version": "1.0"}, + "meta": {"count": 2, "totalPages": 1}, + } + + async def test_select_custom_fields_with_includes( + self, + app: FastAPI, + async_session: AsyncSession, + client: AsyncClient, + user_1: User, + user_2: User, + ): + url = app.url_path_for("get_user_list") + user_1, user_2 = sorted((user_1, user_2), key=lambda x: x.id) + + user_2_post = await build_post(async_session, user_2) + user_1_post = await build_post(async_session, user_1) + + user_1_comment = await build_post_comment(async_session, user_1, user_2_post) + user_2_comment = await build_post_comment(async_session, user_2, user_1_post) + + queried_user_fields = "name" + queried_post_fields = "title" + + params = QueryParams( + [ + ("fields[user]", queried_user_fields), + ("fields[post]", queried_post_fields), + # empty str means ignore all fields + ("fields[post_comment]", ""), + ("include", "posts,posts.comments"), + ("sort", "id"), + ], + ) + response = await client.get(url, params=str(params)) + + assert response.status_code == status.HTTP_200_OK, response.text + response_data = response.json() + response_data["included"] = sorted(response_data["included"], key=lambda x: (x["type"], x["id"])) + + assert response_data == { + "data": [ + { + "attributes": UserAttributesBaseSchema.from_orm(user_1).dict( + include=set(queried_user_fields.split(",")), + ), + "relationships": { + "posts": { + "data": [ + { + "id": str(user_1_post.id), + "type": "post", + }, + ], + }, + }, + "id": str(user_1.id), + "type": "user", + }, + { + "attributes": UserAttributesBaseSchema.from_orm(user_2).dict( + include=set(queried_user_fields.split(",")), + ), + "relationships": { + "posts": { + "data": [ + { + "id": str(user_2_post.id), + "type": "post", + }, + ], + }, + }, + "id": str(user_2.id), + "type": "user", + }, + ], + "jsonapi": {"version": "1.0"}, + "meta": {"count": 2, "totalPages": 1}, + "included": sorted( + [ + { + "attributes": PostAttributesBaseSchema.from_orm(user_2_post).dict( + include=set(queried_post_fields.split(",")), + ), + "id": str(user_2_post.id), + "relationships": { + "comments": { + "data": [ + { + "id": str(user_1_comment.id), + "type": "post_comment", + }, + ], + }, + }, + "type": "post", + }, + { + "attributes": PostAttributesBaseSchema.from_orm(user_1_post).dict( + include=set(queried_post_fields.split(",")), + ), + "id": str(user_1_post.id), + "relationships": { + "comments": {"data": [{"id": str(user_2_comment.id), "type": "post_comment"}]}, + }, + "type": "post", + }, + { + "attributes": {}, + "id": str(user_1_comment.id), + "type": "post_comment", + }, + { + "attributes": {}, + "id": str(user_2_comment.id), + "type": "post_comment", + }, + ], + key=lambda x: (x["type"], x["id"]), + ), + } + + async def test_select_custom_fields_for_includes_without_requesting_includes( + self, + app: FastAPI, + client: AsyncClient, + user_1: User, + ): + url = app.url_path_for("get_user_list") + + params = QueryParams([("fields[post]", "title")]) + response = await client.get(url, params=str(params)) + + assert response.status_code == status.HTTP_200_OK, response.text + response_data = response.json() + + assert response_data == { + "data": [ + { + "attributes": UserAttributesBaseSchema.from_orm(user_1), + "id": str(user_1.id), + "type": "user", + }, + ], + "jsonapi": {"version": "1.0"}, + "meta": {"count": 1, "totalPages": 1}, + } + class TestCreatePostAndComments: async def test_get_posts_with_users( @@ -266,7 +470,7 @@ async def test_create_comments_for_post( user_2: User, user_1_post: Post, ): - url = app.url_path_for("get_comment_list") + url = app.url_path_for("get_post_comment_list") url = f"{url}?include=author,post,post.user" comment_attributes = PostCommentAttributesBaseSchema( text=fake.sentence(), @@ -297,7 +501,7 @@ async def test_create_comments_for_post( comment_id = comment_data.pop("id") assert comment_id assert comment_data == { - "type": "comment", + "type": "post_comment", "attributes": comment_attributes, "relationships": { "post": { @@ -355,7 +559,7 @@ async def test_create_comment_error_no_relationship( :param user_1_post: :return: """ - url = app.url_path_for("get_comment_list") + url = app.url_path_for("get_post_comment_list") comment_attributes = PostCommentAttributesBaseSchema( text=fake.sentence(), ).dict() @@ -396,7 +600,7 @@ async def test_create_comment_error_no_relationships_content( app: FastAPI, client: AsyncClient, ): - url = app.url_path_for("get_comment_list") + url = app.url_path_for("get_post_comment_list") comment_attributes = PostCommentAttributesBaseSchema( text=fake.sentence(), ).dict() @@ -442,7 +646,7 @@ async def test_create_comment_error_no_relationships_field( app: FastAPI, client: AsyncClient, ): - url = app.url_path_for("get_comment_list") + url = app.url_path_for("get_post_comment_list") comment_attributes = PostCommentAttributesBaseSchema( text=fake.sentence(), ).dict() @@ -613,6 +817,35 @@ async def test_many_to_many_load_inner_includes_to_parents( assert ("child", ViewBase.get_db_item_id(child_4)) not in included_data +class TestGetUserDetail: + def get_url(self, app: FastAPI, user_id: int) -> str: + return app.url_path_for("get_user_detail", obj_id=user_id) + + async def test_select_custom_fields( + self, + app: FastAPI, + client: AsyncClient, + user_1: User, + ): + url = self.get_url(app, user_1.id) + queried_user_fields = "name,age" + params = QueryParams([("fields[user]", queried_user_fields)]) + response = await client.get(url, params=params) + + assert response.status_code == status.HTTP_200_OK + assert response.json() == { + "data": { + "attributes": UserAttributesBaseSchema.from_orm(user_1).dict( + include=set(queried_user_fields.split(",")), + ), + "id": str(user_1.id), + "type": "user", + }, + "jsonapi": {"version": "1.0"}, + "meta": None, + } + + class TestUserWithPostsWithInnerIncludes: @mark.parametrize( "include, expected_relationships_inner_relations, expect_user_include", @@ -1336,6 +1569,35 @@ class ContainsTimestampAttrsSchema(BaseModel): "data": [], } + async def test_select_custom_fields(self, app: FastAPI, client: AsyncClient): + user_attrs_schema = UserAttributesBaseSchema( + name=fake.name(), + age=fake.pyint(), + email=fake.email(), + ) + create_user_body = { + "data": { + "attributes": user_attrs_schema.dict(), + }, + } + queried_user_fields = "name" + params = QueryParams([("fields[user]", queried_user_fields)]) + url = app.url_path_for("get_user_list") + res = await client.post(url, json=create_user_body, params=params) + assert res.status_code == status.HTTP_201_CREATED, res.text + response_data: dict = res.json() + + assert "data" in response_data + assert response_data["data"].pop("id") + assert response_data == { + "data": { + "attributes": user_attrs_schema.dict(include=set(queried_user_fields.split(","))), + "type": "user", + }, + "jsonapi": {"version": "1.0"}, + "meta": None, + } + class TestPatchObjects: async def test_patch_object( @@ -1448,6 +1710,40 @@ async def test_update_schema_has_extra_fields(self, user_1: User, caplog): ): assert expected in log_message + async def test_select_custom_fields( + self, + app: FastAPI, + client: AsyncClient, + user_1: User, + ): + new_attrs = UserAttributesBaseSchema( + name=fake.name(), + age=fake.pyint(), + email=fake.email(), + ) + + patch_user_body = { + "data": { + "id": user_1.id, + "attributes": new_attrs.dict(), + }, + } + queried_user_fields = "name" + params = QueryParams([("fields[user]", queried_user_fields)]) + url = app.url_path_for("get_user_detail", obj_id=user_1.id) + res = await client.patch(url, params=params, json=patch_user_body) + + assert res.status_code == status.HTTP_200_OK, res.text + assert res.json() == { + "data": { + "attributes": new_attrs.dict(include=set(queried_user_fields.split(","))), + "id": str(user_1.id), + "type": "user", + }, + "jsonapi": {"version": "1.0"}, + "meta": None, + } + class TestPatchObjectRelationshipsToOne: async def test_ok_when_foreign_key_of_related_object_is_nullable( @@ -1918,6 +2214,39 @@ async def test_delete_objects_many( "meta": {"count": 1, "totalPages": 1}, } + async def test_select_custom_fields( + self, + app: FastAPI, + client: AsyncClient, + user_1: User, + user_2: User, + ): + queried_user_fields = "name" + params = QueryParams([("fields[user]", queried_user_fields)]) + url = app.url_path_for("get_user_list") + res = await client.delete(url, params=params) + assert res.status_code == status.HTTP_200_OK, res.text + assert res.json() == { + "data": [ + { + "attributes": UserAttributesBaseSchema.from_orm(user_1).dict( + include=set(queried_user_fields.split(",")), + ), + "id": str(user_1.id), + "type": "user", + }, + { + "attributes": UserAttributesBaseSchema.from_orm(user_2).dict( + include=set(queried_user_fields.split(",")), + ), + "id": str(user_2.id), + "type": "user", + }, + ], + "jsonapi": {"version": "1.0"}, + "meta": {"count": 2, "totalPages": 1}, + } + class TestOpenApi: def test_openapi_method_ok(self, app: FastAPI): diff --git a/tests/test_fastapi_jsonapi/__init__.py b/tests/test_fastapi_jsonapi/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_fastapi_jsonapi/test_querystring.py b/tests/test_fastapi_jsonapi/test_querystring.py new file mode 100644 index 00000000..dec04365 --- /dev/null +++ b/tests/test_fastapi_jsonapi/test_querystring.py @@ -0,0 +1,92 @@ +import json +from unittest.mock import MagicMock + +import pytest +from fastapi import status +from starlette.datastructures import QueryParams + +from fastapi_jsonapi.exceptions import InvalidFilters +from fastapi_jsonapi.exceptions.json_api import BadRequest +from fastapi_jsonapi.querystring import QueryStringManager + + +def test__extract_item_key(): + manager = QueryStringManager(MagicMock()) + + key = "fields[user]" + assert manager._extract_item_key(key) == "user" + + with pytest.raises(BadRequest) as exc_info: # noqa: PT012 + key = "fields[user" + manager._extract_item_key(key) + + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + assert exc_info.value.detail == { + "errors": [ + { + "status_code": status.HTTP_400_BAD_REQUEST, + "source": {"parameter": "fields[user"}, + "title": "Bad Request", + "detail": "Parse error", + }, + ], + } + + +def test_querystring(): + request = MagicMock() + request.query_params = QueryParams([("fields[user]", "name")]) + manager = QueryStringManager(request) + assert manager.querystring == {"fields[user]": "name"} + + +def test_filters__errors(): + request = MagicMock() + request.query_params = QueryParams([("filter", "not_json")]) + manager = QueryStringManager(request) + + with pytest.raises(InvalidFilters) as exc_info: + manager.filters + + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + assert exc_info.value.detail == { + "errors": [ + { + "status_code": status.HTTP_400_BAD_REQUEST, + "source": {"parameter": "filters"}, + "title": "Invalid filters querystring parameter.", + "detail": "Parse error", + }, + ], + } + + request.query_params = QueryParams( + [ + ( + "filter", + json.dumps( + { + "name": "", + "op": "", + "val": "", + }, + ), + ), + ], + ) + manager = QueryStringManager(request) + + with pytest.raises(InvalidFilters) as exc_info: + manager.filters + + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + assert exc_info.value.detail == { + "errors": [ + { + "status_code": status.HTTP_400_BAD_REQUEST, + "source": {"parameter": "filters"}, + "title": "Invalid filters querystring parameter.", + "detail": "Incorrect filters format, expected list of conditions but got dict", + }, + ], + }