import asyncio
import inspect
import os as _os
import sys
import typing
import unittest
from asyncio.events import AbstractEventLoop
from functools import partial, wraps
from types import ModuleType
from typing import (
    Any,
    Callable,
    Coroutine,
    Iterable,
    List,
    Optional,
    TypeVar,
    Union,
    cast,
)
from unittest import SkipTest, expectedFailure, skip, skipIf, skipUnless

from tortoise import Model, Tortoise, connections
from tortoise.backends.base.config_generator import generate_config as _generate_config
from tortoise.exceptions import DBConnectionError, OperationalError

# ``ParamSpec`` is only in the stdlib ``typing`` module from Python 3.10;
# older interpreters import it from ``typing_extensions`` instead.
if sys.version_info >= (3, 10):
    from typing import ParamSpec
else:
    from typing_extensions import ParamSpec

# Public API of this module (also re-exports the unittest skip helpers).
__all__ = (
    "MEMORY_SQLITE",
    "SimpleTestCase",
    "TestCase",
    "TruncationTestCase",
    "IsolatedTestCase",
    "getDBConfig",
    "requireCapability",
    "env_initializer",
    "initializer",
    "finalizer",
    "SkipTest",
    "expectedFailure",
    "skip",
    "skipIf",
    "skipUnless",
    "init_memory_sqlite",
)
# DB URL used by ``getDBConfig()``; ``initializer()`` overwrites it when it is
# given an explicit ``db_url``.
_TORTOISE_TEST_DB = "sqlite://:memory:"
# pylint: disable=W0201

# Give the re-exported ``expectedFailure`` a docstring of its own.
expectedFailure.__doc__ = """Mark test as expecting failure.On success it will be marked as unexpected success."""

# Module-level state captured by ``initializer()``. ``_restore_default()``
# reads _CONFIG / _CONNECTIONS / _CONN_CONFIG, ``finalizer()`` reads _LOOP and
# ``IsolatedTestCase`` falls back to _MODULES.
_CONFIG: dict = {}
_CONNECTIONS: dict = {}
_LOOP: AbstractEventLoop = None  # type: ignore
_MODULES: Iterable[Union[str, ModuleType]] = []
_CONN_CONFIG: dict = {}
def getDBConfig(app_label: str, modules: Iterable[Union[str, ModuleType]]) -> dict:
    """
    Build a Tortoise configuration dict pointing at the test database.

    :param app_label: Label of the app (must be distinct for multiple apps).
        It is also used as the connection label.
    :param modules: List of modules to look for models in.
    :return: The generated config, flagged for testing.
    """
    app_modules = {app_label: modules}
    return _generate_config(
        _TORTOISE_TEST_DB,
        app_modules=app_modules,
        testing=True,
        connection_label=app_label,
    )
async def _init_db(config: dict) -> None:
    """
    (Re)create the database(s) described by ``config`` and generate the schema.

    Any existing database is dropped first on a best-effort basis:
    ``DBConnectionError`` / ``OperationalError`` raised by the drop are ignored.

    :param config: Tortoise config dict, e.g. as produced by ``getDBConfig()``.
    """
    # Placing init outside the try block since it doesn't
    # establish connections to the DB eagerly.
    await Tortoise.init(config)
    try:
        await Tortoise._drop_databases()
    except (DBConnectionError, OperationalError):  # pragma: nocoverage
        # Best effort: a failed drop (e.g. nothing to drop) is not fatal.
        pass
    # Initialise again; ``_create_db=True`` presumably makes Tortoise create
    # the database before connecting -- TODO confirm against Tortoise.init.
    await Tortoise.init(config, _create_db=True)
    await Tortoise.generate_schemas(safe=False)


def _restore_default() -> None:
    """
    Reinstate the Tortoise state snapshotted by ``initializer()``.

    Clears ``Tortoise.apps``, reloads the saved connections and connection
    config (installing shallow copies so later changes to the live mappings do
    not alter the saved snapshot dicts), re-runs app initialisation from
    ``_CONFIG["apps"]`` and marks Tortoise as initialised.

    :raises KeyError: if ``initializer()`` has not populated ``_CONFIG`` yet.
    """
    Tortoise.apps = {}
    connections._get_storage().update(_CONNECTIONS.copy())
    connections._db_config = _CONN_CONFIG.copy()
    Tortoise._init_apps(_CONFIG["apps"])
    Tortoise._inited = True
def initializer(
    modules: Iterable[Union[str, ModuleType]],
    db_url: Optional[str] = None,
    app_label: str = "models",
    loop: Optional[AbstractEventLoop] = None,
) -> None:
    """
    Sets up the DB for testing. Must be called as part of test environment setup.

    The database is created and its schema generated. Then the resulting
    connections/config are snapshotted into module globals (later restored by
    ``_restore_default()``), and Tortoise itself is reset to an uninitialised
    state.

    :param modules: List of modules to look for models in.
    :param db_url: The db_url, defaults to ``sqlite://:memory:``.
    :param app_label: The name of the APP to initialise the modules in, defaults to "models"
    :param loop: Optional event loop. When omitted, the current event loop is
        used; if there is none, a new loop is created and installed as the
        current one.
    """
    # pylint: disable=W0603
    global _CONFIG
    global _CONNECTIONS
    global _LOOP
    global _TORTOISE_TEST_DB
    global _MODULES
    global _CONN_CONFIG
    _MODULES = modules
    if db_url is not None:  # pragma: nobranch
        _TORTOISE_TEST_DB = db_url
    _CONFIG = getDBConfig(app_label=app_label, modules=_MODULES)

    if loop is None:
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No current loop (non-main thread, or Python 3.14+, where
            # get_event_loop() no longer creates one implicitly). Install a new
            # one so later asyncio.get_event_loop() calls (as SimpleTestCase
            # does) resolve to this same loop.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
    _LOOP = loop
    loop.run_until_complete(_init_db(_CONFIG))

    # Snapshot the live connections/config for _restore_default(), then reset
    # the live state so each test case starts from the snapshot.
    _CONNECTIONS = connections._copy_storage()
    _CONN_CONFIG = connections.db_config.copy()
    connections._clear_storage()
    connections.db_config.clear()
    Tortoise.apps = {}
    Tortoise._inited = False
def finalizer() -> None:
    """
    Tear down the test DB; call it from the test environment teardown.

    Restores the Tortoise state captured by ``initializer()``. It then drops
    the test databases, running on the event loop that ``initializer()``
    stored.
    """
    _restore_default()
    _LOOP.run_until_complete(Tortoise._drop_databases())
def env_initializer() -> None:  # pragma: nocoverage
    """
    Calls ``initializer()`` with parameters mapped from environment variables.

    ``TORTOISE_TEST_MODULES``:
        A comma-separated list of modules to include *(optional)*.
        Whitespace around entries is ignored and empty entries are dropped.
        If not set, it defaults to ``tests.testmodels``. If it is set but
        names no modules, an exception is raised.

    ``TORTOISE_TEST_APP``:
        The name of the APP to initialise the modules in *(optional)*.
        If not provided, it will default to "models".

    ``TORTOISE_TEST_DB``:
        The db_url of the test db. *(optional)*
        If not provided, it will default to an in-memory SQLite DB.
    """
    raw_modules = _os.environ.get("TORTOISE_TEST_MODULES", "tests.testmodels")
    # str.split() never returns an empty list, so blank entries must be
    # filtered out for the emptiness check below to be able to fire.
    modules = [name.strip() for name in raw_modules.split(",") if name.strip()]
    db_url = _os.environ.get("TORTOISE_TEST_DB", "sqlite://:memory:")
    app_label = _os.environ.get("TORTOISE_TEST_APP", "models")
    if not modules:  # pragma: nocoverage
        raise Exception("TORTOISE_TEST_MODULES envvar not defined")
    initializer(modules, db_url=db_url, app_label=app_label)
class SimpleTestCase(unittest.IsolatedAsyncioTestCase):
    """
    The Tortoise base test class.

    This will ensure that your DB environment has a test double set up for use.

    An asyncio capable test class that provides some helper functions.

    Will run any ``test_*()`` function either as sync or async, depending
    on the signature of the function.
    If you specify ``async test_*()`` then it will run it in an event loop.

    Subclasses customise per-test database handling by overriding the
    ``_setUpDB()`` / ``_tearDownDB()`` coroutines. These are awaited from
    ``asyncSetUp()`` / ``asyncTearDown()``. If you override either of those
    two methods, call ``super()`` so the DB hooks still run.

    Based on `asynctest <http://asynctest.readthedocs.io/>`_
    """

    def _setupAsyncioRunner(self) -> None:
        # Run tests on the *current* event loop (typically the one
        # initializer() ran on) rather than a fresh per-test loop.
        if hasattr(asyncio, "Runner"):  # For python3.11+
            runner = asyncio.Runner(debug=True, loop_factory=asyncio.get_event_loop)
            self._asyncioRunner = runner

    def _tearDownAsyncioRunner(self) -> None:
        # Override runner tear down to avoid eventloop closing before testing completed.
        pass

    async def asyncSetUp(self) -> None:
        await super().asyncSetUp()
        # Without this call the _setUpDB() overrides of subclasses never run.
        await self._setUpDB()

    async def asyncTearDown(self) -> None:
        await self._tearDownDB()
        await super().asyncTearDown()

    async def _setUpDB(self) -> None:
        """Per-test DB set-up hook; a no-op here, overridden by subclasses."""

    async def _tearDownDB(self) -> None:
        """Per-test DB tear-down hook; a no-op here, overridden by subclasses."""

    def _setupAsyncioLoop(self):
        # Presumably the pre-3.11 (no asyncio.Runner) counterpart of
        # _setupAsyncioRunner: reuse the current loop in debug mode, start the
        # base class's runner task and block until it resolves ``fut``.
        loop = asyncio.get_event_loop()
        loop.set_debug(True)
        self._asyncioTestLoop = loop
        fut = loop.create_future()
        self._asyncioCallsTask = loop.create_task(self._asyncioLoopRunner(fut))  # type: ignore
        loop.run_until_complete(fut)

    def _tearDownAsyncioLoop(self):
        # Enqueue a ``None`` sentinel (presumably stopping the runner task) and
        # wait for the queue to drain. The loop is deliberately NOT closed so
        # later tests can keep using it.
        loop = self._asyncioTestLoop
        self._asyncioTestLoop = None  # type: ignore
        self._asyncioCallsQueue.put_nowait(None)  # type: ignore
        loop.run_until_complete(self._asyncioCallsQueue.join())  # type: ignore
class IsolatedTestCase(SimpleTestCase):
    """
    An asyncio capable test class that will ensure that an isolated test db
    is available for each test.

    Use this if your test needs perfect isolation.

    Note: use ``{}`` as a string-replacement parameter in your DB_URL.
    That will create a randomised database name.

    It will create and destroy a new DB instance for every test.
    This is obviously slow, but guarantees a fresh DB.

    If you define a ``tortoise_test_modules`` list, it overrides the DB setup
    module for the tests.
    """

    # Model modules for this class's test DB. When empty (falsy), the modules
    # passed to initializer() are used instead.
    tortoise_test_modules: Iterable[Union[str, ModuleType]] = []

    async def _setUpDB(self) -> None:
        await super()._setUpDB()
        config = getDBConfig(
            app_label="models", modules=self.tortoise_test_modules or _MODULES
        )
        await Tortoise.init(config, _create_db=True)
        await Tortoise.generate_schemas(safe=False)

    async def _tearDownDB(self) -> None:
        await Tortoise._drop_databases()
        # Chain to the parent hook, mirroring _setUpDB() and
        # TruncationTestCase, so cooperative subclasses/mixins also tear down.
        await super()._tearDownDB()
class TruncationTestCase(SimpleTestCase):
    """
    Asyncio-capable test class that empties every model table after each test.

    Choose it when the code under test runs transactions of its own. Speed-wise
    it sits between ``TestCase`` (faster) and ``IsolatedTestCase`` (slower).

    Auto-incrementing primary keys are not guaranteed to restart at 1.
    """

    async def _setUpDB(self) -> None:
        await super()._setUpDB()
        _restore_default()

    async def _tearDownDB(self) -> None:
        _restore_default()
        # TODO: This is a naive implementation: Will fail to clear M2M and non-cascade foreign keys
        registered_models = (
            model
            for app_models in Tortoise.apps.values()
            for model in app_models.values()
        )
        for model in registered_models:
            db = model._meta.db
            quote = db.query_class._builder().QUOTE_CHAR
            # Quote the table name with the dialect's identifier quote char.
            table_name = f"{quote}{model._meta.db_table}{quote}"
            await db.execute_script(f"DELETE FROM {table_name}")  # nosec
        await super()._tearDownDB()
class TransactionTestContext:
    """
    Async context manager that runs its body in a transaction on ``connection``
    and always rolls it back on exit.

    NOTE(review): ``connection`` is assumed to be a transaction-capable
    connection wrapper. This is inferred from the attributes used below
    (``connection_name``, ``_parent``, ``_connection``, ``start()``,
    ``rollback()``).
    """

    __slots__ = ("connection", "connection_name", "token", "uses_pool")

    def __init__(self, connection) -> None:
        self.connection = connection
        self.connection_name = connection.connection_name
        # A ``_pool`` attribute on the parent client marks a pooled backend.
        self.uses_pool = hasattr(self.connection._parent, "_pool")

    async def ensure_connection(self) -> None:
        """Make sure ``self.connection`` has an underlying DB connection."""
        is_conn_established = self.connection._connection is not None
        if self.uses_pool:
            is_conn_established = self.connection._parent._pool is not None

        # If the underlying pool/connection hasn't been established then
        # first create the pool/connection
        if not is_conn_established:
            await self.connection._parent.create_connection(with_db=True)

        # Pooled: take a dedicated connection (released again in __aexit__).
        if self.uses_pool:
            self.connection._connection = await self.connection._parent._pool.acquire()
        else:
            self.connection._connection = self.connection._parent._connection

    async def __aenter__(self):
        await self.ensure_connection()
        # Register this connection under ``connection_name``; the token lets
        # __aexit__ undo the registration.
        self.token = connections.set(self.connection_name, self.connection)
        await self.connection.start()
        return self.connection

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        # Cleanup must run even if the rollback fails. Otherwise a pooled
        # connection leaks, and ``connections`` keeps pointing at this broken
        # transaction for subsequent tests.
        try:
            await self.connection.rollback()
        finally:
            try:
                if self.uses_pool:
                    await self.connection._parent._pool.release(self.connection._connection)
            finally:
                connections.reset(self.token)
class TestCase(TruncationTestCase):
    """
    An asyncio capable test class that will ensure that each test will be run
    in a separate transaction that is rolled back when the test finishes.

    This is a fast test runner. Don't use it if your test uses transactions.
    """

    # NOTE(review): as written, this class defines no members of its own and so
    # inherits all set-up/tear-down behaviour from TruncationTestCase. Nothing
    # here wires up the per-test rollback promised above (cf.
    # TransactionTestContext) -- confirm this is intentional.
[docs]defrequireCapability(connection_name:str="models",**conditions:Any)->Callable:""" Skip a test if the required capabilities are not matched. .. note:: The database must be initialized *before* the decorated test runs. Usage: .. code-block:: python3 @requireCapability(dialect='sqlite') async def test_run_sqlite_only(self): ... Or to conditionally skip a class: .. code-block:: python3 @requireCapability(dialect='sqlite') class TestSqlite(test.TestCase): ... :param connection_name: name of the connection to retrieve capabilities from. :param conditions: capability tests which must all pass for the test to run. """defdecorator(test_item):ifnotisinstance(test_item,type):defcheck_capabilities()->None:db=connections.get(connection_name)forkey,valinconditions.items():ifgetattr(db.capabilities,key)!=val:raiseSkipTest(f"Capability {key} != {val}")ifhasattr(asyncio,"Runner")andinspect.iscoroutinefunction(test_item):# For python3.11+@wraps(test_item)asyncdefskip_wrapper(*args,**kwargs):check_capabilities()returnawaittest_item(*args,**kwargs)else:@wraps(test_item)defskip_wrapper(*args,**kwargs):check_capabilities()returntest_item(*args,**kwargs)returnskip_wrapper# Assume a class is decoratedfuncs={var:fforvarindir(test_item)ifvar.startswith("test_")andcallable(f:=getattr(test_item,var))}forname,funcinfuncs.items():setattr(test_item,name,requireCapability(connection_name=connection_name,**conditions)(func),)returntest_itemreturndecorator
[docs]definit_memory_sqlite(models:Union[ModulesConfigType,AsyncFunc,None]=None)->Union[AsyncFunc,AsyncFuncDeco]:""" For single file style to run code with memory sqlite :param models: list_of_modules that should be discovered for models, default to ['__main__']. Usage: .. code-block:: python3 from tortoise import fields, models, run_async from tortoise.contrib.test import init_memory_sqlite class MyModel(models.Model): id = fields.IntField(primary_key=True) name = fields.TextField() @init_memory_sqlite async def run(): obj = await MyModel.create(name='') assert obj.id == 1 if __name__ == '__main__' run_async(run) Custom models example: .. code-block:: python3 @init_memory_sqlite(models=['app.models', 'aerich.models']) async def run(): ... """defwrapper(func:AsyncFunc,ms:List[str]):@wraps(func)asyncdefrunner(*args,**kwargs)->T:awaitTortoise.init(db_url=MEMORY_SQLITE,modules={"models":ms})awaitTortoise.generate_schemas()returnawaitfunc(*args,**kwargs)returnrunnerdefault_models=["__main__"]ifinspect.iscoroutinefunction(models):returnwrapper(models,default_models)ifmodelsisNone:models=default_modelselifisinstance(models,str):models=[models]else:models=cast(list,models)returnpartial(wrapper,ms=models)