from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, cast

from tortoise import connections
from tortoise.exceptions import ParamsError

if TYPE_CHECKING:  # pragma: nocoverage
    from tortoise.backends.base.client import BaseDBAsyncClient, TransactionContext

# Type variable used by ``atomic`` so the decorated function keeps its own type.
FuncType = Callable[..., Any]
F = TypeVar("F", bound=FuncType)


def _get_connection(connection_name: Optional[str]) -> "BaseDBAsyncClient":
    """
    Look up the database client a transaction should run on.

    :param connection_name: alias of the connection to use; when falsy, the only
        entry in ``connections.db_config`` is used instead.
    :raises ParamsError: when no alias is given and ``connections.db_config``
        does not contain exactly one entry.
    """
    if not connection_name:
        # Without an explicit alias we can only guess when exactly one
        # connection is configured.
        if len(connections.db_config) != 1:
            raise ParamsError(
                "You are running with multiple databases, so you should specify"
                f" connection_name: {list(connections.db_config)}"
            )
        # Unpacking a one-entry mapping yields its single key (the alias).
        (connection_name,) = connections.db_config
    return connections.get(connection_name)
def in_transaction(connection_name: Optional[str] = None) -> "TransactionContext":
    """
    Transaction context manager.

    You can run your code inside an ``async with in_transaction():`` statement to
    run it in one transaction. If an error occurs, the transaction is rolled back.

    :param connection_name: name of the connection to run with; optional if you
        have only one db connection.
    :returns: the object produced by the resolved client's ``_in_transaction()``.
    """
    connection = _get_connection(connection_name)
    return connection._in_transaction()
def atomic(connection_name: Optional[str] = None) -> Callable[[F], F]:
    """
    Transaction decorator.

    You can wrap your function with this decorator to run it in one transaction.
    If an error occurs, the transaction is rolled back.

    The wrapper awaits the decorated function, so it is meant for ``async def``
    functions.

    :param connection_name: name of the connection to run with; optional if you
        have only one db connection.
    """

    def wrapper(func: F) -> F:
        @wraps(func)
        async def wrapped(*args: Any, **kwargs: Any) -> Any:
            # A fresh ``in_transaction`` context is entered on every call, so the
            # connection is resolved at call time rather than at decoration time.
            async with in_transaction(connection_name):
                return await func(*args, **kwargs)

        # ``wrapped`` preserves the call signature via ``wraps``; cast so type
        # checkers keep treating the result as the original function type.
        return cast(F, wrapped)

    return wrapper