From 37e711fc836e1efd1d2e0053985f8f39d12d4cc9 Mon Sep 17 00:00:00 2001 From: Kostis Triantafyllakis Date: Mon, 23 Jan 2023 11:51:24 +0200 Subject: [PATCH] Various client authentication related fixes --- src/idpyoidc/server/client_authn.py | 59 ++++++++++--------- src/idpyoidc/server/endpoint.py | 6 +- src/idpyoidc/server/oidc/session.py | 2 +- src/idpyoidc/server/oidc/userinfo.py | 2 +- tests/test_server_17_client_authn.py | 15 +++-- tests/test_server_20d_client_authn.py | 15 +++-- ...st_server_23_oidc_registration_endpoint.py | 2 +- .../test_server_32_oidc_read_registration.py | 2 +- tests/test_server_60_dpop.py | 6 +- 9 files changed, 58 insertions(+), 51 deletions(-) diff --git a/src/idpyoidc/server/client_authn.py b/src/idpyoidc/server/client_authn.py index 1c62b556..cd3e2c2d 100755 --- a/src/idpyoidc/server/client_authn.py +++ b/src/idpyoidc/server/client_authn.py @@ -262,6 +262,7 @@ def _verify( request: Optional[Union[dict, Message]] = None, authorization_token: Optional[str] = None, endpoint=None, # Optional[Endpoint] + get_client_id_from_token=None, **kwargs, ): _token = request.get("access_token") @@ -269,7 +270,7 @@ def _verify( raise ClientAuthenticationError("No access token") res = {"token": _token} - _client_id = request.get("client_id") + _client_id = get_client_id_from_token(endpoint_context, _token, request) if _client_id: res["client_id"] = _client_id return res @@ -483,6 +484,7 @@ def verify_client( auth_info = {} methods = endpoint_context.client_authn_method + client_id = None allowed_methods = getattr(endpoint, "client_authn_method") if not allowed_methods: allowed_methods = list(methods.keys()) @@ -499,48 +501,47 @@ def verify_client( endpoint=endpoint, get_client_id_from_token=get_client_id_from_token, ) - break except (BearerTokenAuthenticationError, ClientAuthenticationError): raise except Exception as err: logger.info("Verifying auth using {} failed: {}".format(_method.tag, err)) + continue - if auth_info.get("method") == "none": - return auth_info + if auth_info.get("method") == "none" and auth_info.get("client_id") is None: + break - client_id = auth_info.get("client_id") - if client_id is None: - raise ClientAuthenticationError("Failed to verify client") + client_id = auth_info.get("client_id") + if client_id is None: + raise ClientAuthenticationError("Failed to verify client") - if also_known_as: - client_id = also_known_as[client_id] - auth_info["client_id"] = client_id + if also_known_as: + client_id = also_known_as[client_id] + auth_info["client_id"] = client_id - if client_id not in endpoint_context.cdb: - raise UnknownClient("Unknown Client ID") + if client_id not in endpoint_context.cdb: + raise UnknownClient("Unknown Client ID") - _cinfo = endpoint_context.cdb[client_id] + _cinfo = endpoint_context.cdb[client_id] - if not valid_client_info(_cinfo): - logger.warning("Client registration has timed out or " "client secret is expired.") - raise InvalidClient("Not valid client") + if not valid_client_info(_cinfo): + logger.warning("Client registration has timed out or " "client secret is expired.") + raise InvalidClient("Not valid client") - # Validate that the used method is allowed for this client/endpoint - client_allowed_methods = _cinfo.get( - f"{endpoint.endpoint_name}_client_authn_method", _cinfo.get("client_authn_method") - ) - if client_allowed_methods is not None and _method and _method.tag not in client_allowed_methods: - logger.info( - f"Allowed methods for client: {client_id} at endpoint: {endpoint.name} are: " - f"`{', '.join(client_allowed_methods)}`" - ) - raise UnAuthorizedClient( - f"Authentication method: {_method.tag} not allowed for client: {client_id} in " - f"endpoint: {endpoint.name}" + # Validate that the used method is allowed for this client/endpoint + client_allowed_methods = _cinfo.get( + f"{endpoint.endpoint_name}_client_authn_method", _cinfo.get("client_authn_method") ) + if client_allowed_methods is not None and auth_info["method"] not in client_allowed_methods: + logger.info( + f"Allowed methods for client: {client_id} at endpoint: {endpoint.name} are: " + f"`{', '.join(client_allowed_methods)}`" + ) + auth_info = {} + continue + break # store what authn method was used - if auth_info.get("method"): + if "method" in auth_info and client_id: _request_type = request.__class__.__name__ _used_authn_method = _cinfo.get("auth_method") if _used_authn_method: diff --git a/src/idpyoidc/server/endpoint.py b/src/idpyoidc/server/endpoint.py index 24e7b3f6..2167285f 100755 --- a/src/idpyoidc/server/endpoint.py +++ b/src/idpyoidc/server/endpoint.py @@ -132,6 +132,9 @@ def set_client_authn_methods(self, **kwargs): kwargs[self.auth_method_attribute] = _methods elif _methods is not None: # [] or '' or something not None but regarded as nothing. self.client_authn_method = ["none"] # Ignore default value + elif self.default_capabilities: + self.client_authn_method = self.default_capabilities.get("client_authn_method") + self.endpoint_info = construct_provider_info(self.default_capabilities, **kwargs) return kwargs def get_provider_info_attributes(self): @@ -249,7 +252,8 @@ def client_authentication(self, request: Message, http_info: Optional[dict] = No if authn_info == {} and self.client_authn_method and len(self.client_authn_method): LOGGER.debug("client_authn_method: %s", self.client_authn_method) raise UnAuthorizedClient("Authorization failed") - + if "client_id" not in authn_info and authn_info.get("method") != "none": + raise UnAuthorizedClient("Authorization failed") return authn_info def do_post_parse_request( diff --git a/src/idpyoidc/server/oidc/session.py b/src/idpyoidc/server/oidc/session.py index 5768cf5b..716b8277 100644 --- a/src/idpyoidc/server/oidc/session.py +++ b/src/idpyoidc/server/oidc/session.py @@ -361,7 +361,7 @@ def parse_request(self, request, http_info=None, **kwargs): # Verify that the client is allowed to do this auth_info = self.client_authentication(request, http_info, **kwargs) - if not auth_info or auth_info["method"] == "none": + if not auth_info: pass elif isinstance(auth_info, ResponseMessage): return auth_info diff --git a/src/idpyoidc/server/oidc/userinfo.py b/src/idpyoidc/server/oidc/userinfo.py index 6b5473d0..ae6e87b5 100755 --- a/src/idpyoidc/server/oidc/userinfo.py +++ b/src/idpyoidc/server/oidc/userinfo.py @@ -182,7 +182,7 @@ def parse_request(self, request, http_info=None, **kwargs): try: auth_info = self.client_authentication(request, http_info, **kwargs) except ClientAuthenticationError as e: - return self.error_cls(error="invalid_token", error_description=e.args[0]) + return self.error_cls(error="invalid_token", error_description="Invalid token") if isinstance(auth_info, ResponseMessage): return auth_info diff --git a/tests/test_server_17_client_authn.py b/tests/test_server_17_client_authn.py index 4575ecd8..d42a2325 100644 --- a/tests/test_server_17_client_authn.py +++ b/tests/test_server_17_client_authn.py @@ -337,7 +337,7 @@ def create_method(self): def test_bearer_body(self): request = {"access_token": "1234567890"} - assert self.method.verify(request) == {"token": "1234567890", "method": "bearer_body"} + assert self.method.verify(request, get_client_id_from_token=get_client_id_from_token) == {"token": "1234567890", "method": "bearer_body"} def test_bearer_body_no_token(self): request = {} @@ -504,13 +504,12 @@ def test_verify_per_client_per_endpoint(self): ) assert res == {"method": "public", "client_id": client_id} - with pytest.raises(ClientAuthenticationError) as e: - verify_client( - self.endpoint_context, - request, - endpoint=self.server.server_get("endpoint", "endpoint_1"), - ) - assert e.value.args[0] == "Failed to verify client" + res = verify_client( + self.endpoint_context, + request, + endpoint=self.server.server_get("endpoint", "endpoint_1"), + ) + assert res == {} request = {"client_id": client_id, "client_secret": client_secret} res = verify_client( diff --git a/tests/test_server_20d_client_authn.py b/tests/test_server_20d_client_authn.py index e81d26dd..55ab886c 100755 --- a/tests/test_server_20d_client_authn.py +++ b/tests/test_server_20d_client_authn.py @@ -292,7 +292,7 @@ def create_method(self): def test_bearer_body(self): request = {"access_token": "1234567890"} - assert self.method.verify(request) == {"token": "1234567890", "method": "bearer_body"} + assert self.method.verify(request, get_client_id_from_token=get_client_id_from_token) == {"token": "1234567890", "method": "bearer_body"} def test_bearer_body_no_token(self): request = {} @@ -457,13 +457,12 @@ def test_verify_per_client_per_endpoint(self): ) assert res == {"method": "public", "client_id": client_id} - with pytest.raises(ClientAuthenticationError) as e: - verify_client( - self.endpoint_context, - request, - endpoint=self.server.server_get("endpoint", "token"), - ) - assert e.value.args[0] == "Failed to verify client" + res = verify_client( + self.endpoint_context, + request, + endpoint=self.server.server_get("endpoint", "token"), + ) + assert res == {} request = {"client_id": client_id, "client_secret": client_secret} res = verify_client( diff --git a/tests/test_server_23_oidc_registration_endpoint.py b/tests/test_server_23_oidc_registration_endpoint.py index 5b2ef4ae..64cb2a1b 100755 --- a/tests/test_server_23_oidc_registration_endpoint.py +++ b/tests/test_server_23_oidc_registration_endpoint.py @@ -127,7 +127,7 @@ def create_endpoint(self): "registration": { "path": "registration", "class": Registration, - "kwargs": {"client_auth_method": None}, + "kwargs": {"client_authn_method": ["none"]}, }, "authorization": { "path": "authorization", diff --git a/tests/test_server_32_oidc_read_registration.py b/tests/test_server_32_oidc_read_registration.py index 1f7670ad..2e803ba7 100644 --- a/tests/test_server_32_oidc_read_registration.py +++ b/tests/test_server_32_oidc_read_registration.py @@ -95,7 +95,7 @@ def create_endpoint(self): "registration": { "path": "registration", "class": Registration, - "kwargs": {"client_auth_method": None}, + "kwargs": {"client_authn_method": ["none"]}, }, "registration_api": { "path": "registration_api", diff --git a/tests/test_server_60_dpop.py b/tests/test_server_60_dpop.py index cd0301ef..7b74e172 100644 --- a/tests/test_server_60_dpop.py +++ b/tests/test_server_60_dpop.py @@ -164,7 +164,11 @@ def create_endpoint(self): "class": Authorization, "kwargs": {}, }, - "token": {"path": "{}/token", "class": Token, "kwargs": {}}, + "token": { + "path": "{}/token", + "class": Token, + "kwargs": {"client_authn_method": ["none"]}, + }, }, "client_authn": verify_client, "authentication": {