diff --git a/setup.py b/setup.py
index e7e6385..693eaaf 100755
--- a/setup.py
+++ b/setup.py
@@ -51,7 +51,7 @@ def run_tests(self):
     author_email="roland@catalogix.se",
     license="Apache 2.0",
     url='https://github.com/IdentityPython/oidcmsg/',
-    packages=["oidcmsg", "oidcmsg/oauth2", "oidcmsg/oidc"],
+    packages=["oidcmsg", "oidcmsg/oauth2", "oidcmsg/oidc", "oidcmsg/storage"],
     package_dir={"": "src"},
     classifiers=[
         "Development Status :: 4 - Beta",
diff --git a/src/oidcmsg/__init__.py b/src/oidcmsg/__init__.py
index 30a5958..28cfc5f 100755
--- a/src/oidcmsg/__init__.py
+++ b/src/oidcmsg/__init__.py
@@ -1,5 +1,5 @@
 __author__ = "Roland Hedberg"
-__version__ = "1.3.3"
+__version__ = "1.4.0"
 
 import os
 from typing import Dict
diff --git a/src/oidcmsg/impexp.py b/src/oidcmsg/impexp.py
index 90d96ee..269916d 100644
--- a/src/oidcmsg/impexp.py
+++ b/src/oidcmsg/impexp.py
@@ -7,6 +7,13 @@
 from cryptojwt.utils import qualified_name
 
 from oidcmsg.message import Message
+from oidcmsg.storage import DictType
+
+
+def fully_qualified_name(cls):
+    """Return "module.ClassName" for a class or for an instance of a class."""
+    _cls = cls if isinstance(cls, type) else cls.__class__
+    return _cls.__module__ + "." + _cls.__name__
 
 
 class ImpExp:
@@ -23,6 +30,15 @@ def dump_attr(self, cls, item, exclude_attributes: Optional[List[str]] = None) -
                 val = as_bytes(item)
             else:
                 val = item
+        elif cls == "DICT_TYPE":
+            if isinstance(item, dict):
+                val = item
+            else:
+                if isinstance(item, DictType):  # item should be a class instance
+                    val = {
+                        "DICT_TYPE": {"class": fully_qualified_name(item), "kwargs": item.kwargs}}
+                else:
+                    raise ValueError("Expected a DictType class")
         elif isinstance(item, Message):
             val = {qualified_name(item.__class__): item.to_dict()}
         elif cls == object:
@@ -83,6 +99,12 @@ def load_attr(
                 val = as_bytes(item)
             else:
                 val = item
+        elif cls == "DICT_TYPE":
+            if list(item.keys()) == ["DICT_TYPE"]:
+                _spec = item["DICT_TYPE"]
+                val = importer(_spec["class"])(**_spec["kwargs"])
+            else:
+                val = item
         elif cls == object:
             val = importer(item)
         elif isinstance(cls, list):
diff --git a/src/oidcmsg/storage/__init__.py b/src/oidcmsg/storage/__init__.py
new file mode 100644
index 0000000..53b3c9b
--- /dev/null
+++ b/src/oidcmsg/storage/__init__.py
@@ -0,0 +1,10 @@
+class DictType(object):
+    """
+    Base class for dictionary-like storage classes.
+
+    The keyword arguments used at instantiation are kept in ``kwargs`` so that
+    ImpExp can dump them and later re-create the instance from them.
+    """
+
+    def __init__(self, **kwargs):
+        self.kwargs = kwargs
diff --git a/src/oidcmsg/storage/abfile.py b/src/oidcmsg/storage/abfile.py
new file mode 100644
index 0000000..a6eb15e
--- /dev/null
+++ b/src/oidcmsg/storage/abfile.py
@@ -0,0 +1,308 @@
+import logging
+import os
+import time
+from typing import Optional
+
+from cryptojwt.utils import importer
+from filelock import FileLock
+
+from oidcmsg.storage import DictType
+from oidcmsg.util import PassThru
+from oidcmsg.util import QPKey
+
+logger = logging.getLogger(__name__)
+
+
+class AbstractFileSystem(DictType):
+    """
+    FileSystem implements a simple file based database.
+    It has a dictionary like interface.
+    Each key maps one-to-one to a file on disc, where the content of the
+    file is the value.
+    ONLY goes one level deep.
+    No directories in directories.
+    """
+
+    def __init__(self,
+                 fdir: Optional[str] = '',
+                 key_conv: Optional[str] = '',
+                 value_conv: Optional[str] = ''):
+        """
+        items = AbstractFileSystem(
+            fdir=fdir,
+            key_conv='oidcmsg.util.QPKey',
+            value_conv='oidcmsg.util.JSON')
+
+        :param fdir: The root of the directory
+        :param key_conv: Converts to/from the key displayed by this class to
+            users of it to something that can be used as a file name.
+            The value of key_conv is the dotted path of a class that has the
+            methods 'serialize'/'deserialize'. If empty, QPKey is used.
+        :param value_conv: As with key_conv you can convert/translate
+            the value bound to a key in the database to something that can easily
+            be stored in a file. Like with key_conv the value of this parameter
+            is the dotted path of a class that has the methods
+            'serialize'/'deserialize'. If empty, PassThru is used.
+        """
+        super(AbstractFileSystem, self).__init__(fdir=fdir, key_conv=key_conv, value_conv=value_conv)
+
+        self.fdir = fdir
+        self.fmtime = {}
+        self.storage = {}
+
+        if key_conv:
+            self.key_conv = importer(key_conv)()
+        else:
+            self.key_conv = QPKey()
+
+        if value_conv:
+            self.value_conv = importer(value_conv)()
+        else:
+            self.value_conv = PassThru()
+
+        if not os.path.isdir(self.fdir):
+            os.makedirs(self.fdir)
+
+        self.synch()
+
+    def get(self, item, default=None):
+        """Return the value bound to item, or default if there is none."""
+        try:
+            return self[item]
+        except KeyError:
+            return default
+
+    def __getitem__(self, item):
+        """
+        Return the value bound to an identifier.
+
+        :param item: The identifier.
+        :return:
+        """
+        item = self.key_conv.serialize(item)
+
+        if self.is_changed(item):
+            logger.info("File content change in {}".format(item))
+            fname = os.path.join(self.fdir, item)
+            self.storage[item] = self._read_info(fname)
+
+        logger.debug('Read from "%s"', item)
+        return self.storage[item]
+
+    def __setitem__(self, key, value):
+        """
+        Binds a value to a specific key. If the file that the key maps to
+        does not exist it will be created. The content of the file will be
+        set to the value given.
+
+        :param key: Identifier
+        :param value: Value that should be bound to the identifier.
+        :return:
+        """
+
+        if not os.path.isdir(self.fdir):
+            os.makedirs(self.fdir, exist_ok=True)
+
+        try:
+            _key = self.key_conv.serialize(key)
+        except KeyError:
+            _key = key
+
+        fname = os.path.join(self.fdir, _key)
+        lock = FileLock('{}.lock'.format(fname))
+        with lock:
+            with open(fname, 'w') as fp:
+                fp.write(self.value_conv.serialize(value))
+
+        self.storage[_key] = value
+        logger.debug('Wrote to "%s"', key)
+        self.fmtime[_key] = self.get_mtime(fname)
+
+    def __delitem__(self, key):
+        """
+        Removes the file bound to a key and the cached value.
+
+        :param key: The key in its file name form; it is NOT passed through
+            key_conv.
+        """
+        fname = os.path.join(self.fdir, key)
+        if os.path.isfile(fname):
+            lock = FileLock('{}.lock'.format(fname))
+            with lock:
+                os.unlink(fname)
+
+        try:
+            del self.storage[key]
+        except KeyError:
+            pass
+
+    def keys(self):
+        """
+        Implements the dict.keys() method
+        """
+        self.synch()
+        for k in self.storage.keys():
+            yield self.key_conv.deserialize(k)
+
+    @staticmethod
+    def get_mtime(fname):
+        """
+        Find the time this file was last modified.
+
+        :param fname: File name
+        :return: The last time the file was modified.
+        """
+        try:
+            mtime = os.stat(fname).st_mtime_ns
+        except OSError:
+            # The file might be right in the middle of being written
+            # so sleep
+            time.sleep(1)
+            mtime = os.stat(fname).st_mtime_ns
+
+        return mtime
+
+    def is_changed(self, item):
+        """
+        Find out if this item has been modified since last
+
+        :param item: A key
+        :return: True/False
+        """
+        fname = os.path.join(self.fdir, item)
+        if os.path.isfile(fname):
+            mtime = self.get_mtime(fname)
+
+            try:
+                _ftime = self.fmtime[item]
+            except KeyError:  # Never been seen before
+                self.fmtime[item] = mtime
+                return True
+
+            if mtime > _ftime:  # has changed
+                self.fmtime[item] = mtime
+                return True
+            else:
+                return False
+        else:
+            logger.error('Could not access {}'.format(fname))
+            raise KeyError(item)
+
+    def _read_info(self, fname):
+        """
+        Read and deserialize the content of a file.
+
+        :param fname: File name
+        :return: The deserialized value, or None if there is no such file.
+        """
+        if os.path.isfile(fname):
+            try:
+                lock = FileLock('{}.lock'.format(fname))
+                with lock:
+                    # Use a context manager so the file handle is always closed.
+                    with open(fname, 'r') as fp:
+                        info = fp.read().strip()
+                return self.value_conv.deserialize(info)
+            except Exception as err:
+                logger.error(err)
+                raise
+        else:
+            logger.error('No such file: {}'.format(fname))
+        return None
+
+    def synch(self):
+        """
+        Goes through the directory and builds a local cache based on
+        the content of the directory.
+        """
+        if not os.path.isdir(self.fdir):
+            os.makedirs(self.fdir)
+            # raise ValueError('No such directory: {}'.format(self.fdir))
+        for f in os.listdir(self.fdir):
+            fname = os.path.join(self.fdir, f)
+
+            if not os.path.isfile(fname):
+                continue
+            if fname.endswith('.lock'):
+                continue
+
+            if f in self.fmtime:
+                if self.is_changed(f):
+                    self.storage[f] = self._read_info(fname)
+            else:
+                mtime = self.get_mtime(fname)
+                try:
+                    self.storage[f] = self._read_info(fname)
+                except Exception as err:
+                    logger.warning('Bad content in {} ({})'.format(fname, err))
+                else:
+                    self.fmtime[f] = mtime
+
+    def items(self):
+        """
+        Implements the dict.items() method
+        """
+        self.synch()
+        for k, v in self.storage.items():
+            yield self.key_conv.deserialize(k), v
+
+    def clear(self):
+        """
+        Completely resets the database. This means that all information in
+        the local cache and on disc will be erased.
+        """
+        if not os.path.isdir(self.fdir):
+            os.makedirs(self.fdir, exist_ok=True)
+            return
+
+        for f in os.listdir(self.fdir):
+            del self[f]
+
+    def update(self, ava):
+        """
+        Replaces what's in the database with a set of key, value pairs.
+        Only data bound to keys that appear in ava will be affected.
+
+        Implements the dict.update() method
+
+        :param ava: Dictionary
+        """
+        for key, val in ava.items():
+            self[key] = val
+
+    def __contains__(self, item):
+        return self.key_conv.serialize(item) in self.storage
+
+    def __iter__(self):
+        """Iterates over (key, value) pairs, like items(), not over keys only."""
+        return self.items()
+
+    def __call__(self, *args, **kwargs):
+        return [self.key_conv.deserialize(k) for k in self.storage.keys()]
+
+    def __len__(self):
+        if not os.path.isdir(self.fdir):
+            return 0
+
+        n = 0
+        for f in os.listdir(self.fdir):
+            fname = os.path.join(self.fdir, f)
+
+            if not os.path.isfile(fname):
+                continue
+            if fname.endswith('.lock'):
+                continue
+
+            n += 1
+        return n
+
+    def __str__(self):
+        # The instantiation arguments are kept in self.kwargs (set by DictType).
+        return '{config:' + str(self.kwargs) + ', info:' + str(self.storage) + '}'
+
+    def dump(self):
+        return {k: v for k, v in self.items()}
+
+    def load(self, info):
+        for k, v in info.items():
+            self[k] = v
diff --git a/src/oidcmsg/util.py b/src/oidcmsg/util.py
index 3aa6275..da3ba0a 100644
--- a/src/oidcmsg/util.py
+++ b/src/oidcmsg/util.py
@@ -1,4 +1,7 @@
+import json
 import secrets
+from urllib.parse import quote_plus
+from urllib.parse import unquote_plus
 
 import yaml
 
@@ -18,3 +21,35 @@ def load_yaml_config(filename):
     with open(filename, "rt", encoding='utf-8') as file:
         config_dict = yaml.safe_load(file)
     return config_dict
+
+
+# Converters
+
+class QPKey:
+    """Converts keys with quote_plus/unquote_plus."""
+
+    def serialize(self, str):
+        return quote_plus(str)
+
+    def deserialize(self, str):
+        return unquote_plus(str)
+
+
+class JSON:
+    """Converts values to/from their JSON text representation."""
+
+    def serialize(self, str):
+        return json.dumps(str)
+
+    def deserialize(self, str):
+        return json.loads(str)
+
+
+class PassThru:
+    """Returns values unchanged in both directions."""
+
+    def serialize(self, str):
+        return str
+
+    def deserialize(self, str):
+        return str
diff --git a/tests/test_21_abfile.py b/tests/test_21_abfile.py
new file mode 100644
index 0000000..eb463eb
--- /dev/null
+++ b/tests/test_21_abfile.py
@@ -0,0 +1,128 @@
+import os
+import shutil
+
+import pytest
+
+from oidcmsg.impexp import ImpExp
+from oidcmsg.storage.abfile import AbstractFileSystem
+
+BASEDIR = os.path.abspath(os.path.dirname(__file__))
+
+
+def full_path(local_file):
+    return os.path.join(BASEDIR, local_file)
+
+
+CLIENT_1 = {
+    "client_secret": 'hemligtkodord',
+    "redirect_uris": [['https://example.com/cb', '']],
+    "client_salt": "salted",
+    'token_endpoint_auth_method': 'client_secret_post',
+    'response_types': ['code', 'token']
+}
+
+CLIENT_2 = {
+    "client_secret": "spraket",
+    "redirect_uris": [['https://app1.example.net/foo', ''],
+                      ['https://app2.example.net/bar', '']],
+    "response_types": ["code"]
+}
+
+
+class ImpExpTest(ImpExp):
+    parameter = {
+        "string": "",
+        "list": [],
+        "dict": "DICT_TYPE",
+    }
+
+
+class TestAFS(object):
+    @pytest.fixture(autouse=True)
+    def setup(self):
+        filename = full_path("afs")
+        if os.path.isdir(filename):
+            shutil.rmtree(filename)
+
+    def test_create_cdb(self):
+        abf = AbstractFileSystem(fdir=full_path("afs"), value_conv='oidcmsg.util.JSON')
+
+        # add a client
+
+        abf['client_1'] = CLIENT_1
+
+        assert list(abf.keys()) == ["client_1"]
+
+        # add another one
+
+        abf['client_2'] = CLIENT_2
+
+        assert set(abf.keys()) == {"client_1", "client_2"}
+
+    def test_read_cdb(self):
+        abf = AbstractFileSystem(fdir=full_path("afs"), value_conv='oidcmsg.util.JSON')
+        # add a client
+        abf['client_1'] = CLIENT_1
+        # add another one
+        abf['client_2'] = CLIENT_2
+
+        afs_2 = AbstractFileSystem(fdir=full_path("afs"), value_conv='oidcmsg.util.JSON')
+        assert set(afs_2.keys()) == {"client_1", "client_2"}
+
+    def test_dump(self):
+        abf = AbstractFileSystem(fdir=full_path("afs"), value_conv='oidcmsg.util.JSON')
+        # add a client
+        abf['client_1'] = CLIENT_1
+        # add another one
+        abf['client_2'] = CLIENT_2
+
+        _dict = abf.dump()
+        assert _dict["client_1"]["client_secret"] == "hemligtkodord"
+        assert _dict["client_2"]["client_secret"] == "spraket"
+
+    def test_dump_load(self):
+        abf = AbstractFileSystem(fdir=full_path("afs"), value_conv='oidcmsg.util.JSON')
+        # add a client
+        abf['client_1'] = CLIENT_1
+        # add another one
+        abf['client_2'] = CLIENT_2
+
+        _dict = abf.dump()
+        afs_2 = AbstractFileSystem(fdir=full_path("afs"), value_conv='oidcmsg.util.JSON')
+        afs_2.load(_dict)
+        assert set(afs_2.keys()) == {"client_1", "client_2"}
+
+    def test_dump_load_afs(self):
+        b = ImpExpTest()
+        b.string = "foo"
+        b.list = ["a", "b", "c"]
+        b.dict = AbstractFileSystem(fdir=full_path("afs"), value_conv='oidcmsg.util.JSON')
+
+        # add a client
+        b.dict['client_1'] = CLIENT_1
+        # add another one
+        b.dict['client_2'] = CLIENT_2
+
+        dump = b.dump()
+
+        b_copy = ImpExpTest().load(dump)
+        assert b_copy
+        assert isinstance(b_copy.dict, AbstractFileSystem)
+        assert set(b_copy.dict.keys()) == {"client_1", "client_2"}
+
+    def test_dump_load_dict(self):
+        b = ImpExpTest()
+        b.string = "foo"
+        b.list = ["a", "b", "c"]
+        b.dict = {"a": 1, "b": 2, "c": 3}
+
+        dump = b.dump()
+
+        b_copy = ImpExpTest().load(dump)
+        assert b_copy
+        assert isinstance(b_copy.dict, dict)
+
+    def test_str(self):
+        abf = AbstractFileSystem(fdir=full_path("afs"), value_conv='oidcmsg.util.JSON')
+        abf['client_1'] = CLIENT_1
+        assert "client_1" in str(abf)