Toggle Light / Dark / Auto color theme
Toggle table of contents sidebar
celery.security.certificate 源代码 """X.509 certificates."""
from __future__ import annotations
import datetime
import glob
import os
from typing import TYPE_CHECKING , Iterator
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import padding , rsa
from cryptography.x509 import load_pem_x509_certificate
from kombu.utils.encoding import bytes_to_str , ensure_bytes
from celery.exceptions import SecurityError
from .utils import reraise_errors
if TYPE_CHECKING :
from cryptography.hazmat.primitives.asymmetric.dsa import DSAPublicKey
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicKey
from cryptography.hazmat.primitives.asymmetric.ed448 import Ed448PublicKey
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
from cryptography.hazmat.primitives.asymmetric.utils import Prehashed
from cryptography.hazmat.primitives.hashes import HashAlgorithm
__all__ = ( 'Certificate' , 'CertStore' , 'FSCertStore' )
[文档]
class Certificate :
"""X.509 certificate."""
def __init__ ( self , cert : str ) -> None :
with reraise_errors (
'Invalid certificate: {0!r} ' , errors = ( ValueError ,)
):
self . _cert = load_pem_x509_certificate (
ensure_bytes ( cert ), backend = default_backend ())
if not isinstance ( self . _cert . public_key (), rsa . RSAPublicKey ):
raise ValueError ( "Non-RSA certificates are not supported." )
[文档]
def has_expired ( self ) -> bool :
"""Check if the certificate has expired."""
return datetime . datetime . now ( datetime . timezone . utc ) >= self . _cert . not_valid_after_utc
[文档]
def get_pubkey ( self ) -> (
DSAPublicKey | EllipticCurvePublicKey | Ed448PublicKey | Ed25519PublicKey | RSAPublicKey
):
return self . _cert . public_key ()
[文档]
def get_serial_number ( self ) -> int :
"""Return the serial number in the certificate."""
return self . _cert . serial_number
[文档]
def get_issuer ( self ) -> str :
"""Return issuer (CA) as a string."""
return ' ' . join ( x . value for x in self . _cert . issuer )
[文档]
def get_id ( self ) -> str :
"""Serial number/issuer pair uniquely identifies a certificate."""
return f ' { self . get_issuer () } { self . get_serial_number () } '
[文档]
def verify ( self , data : bytes , signature : bytes , digest : HashAlgorithm | Prehashed ) -> None :
"""Verify signature for string containing data."""
with reraise_errors ( 'Bad signature: {0!r} ' ):
pad = padding . PSS (
mgf = padding . MGF1 ( digest ),
salt_length = padding . PSS . MAX_LENGTH )
self . get_pubkey () . verify ( signature , ensure_bytes ( data ), pad , digest )
[文档]
class CertStore :
"""Base class for certificate stores."""
def __init__ ( self ) -> None :
self . _certs : dict [ str , Certificate ] = {}
[文档]
def itercerts ( self ) -> Iterator [ Certificate ]:
"""Return certificate iterator."""
yield from self . _certs . values ()
def __getitem__ ( self , id : str ) -> Certificate :
"""Get certificate by id."""
try :
return self . _certs [ bytes_to_str ( id )]
except KeyError :
raise SecurityError ( f 'Unknown certificate: { id !r} ' )
[文档]
def add_cert ( self , cert : Certificate ) -> None :
cert_id = bytes_to_str ( cert . get_id ())
if cert_id in self . _certs :
raise SecurityError ( f 'Duplicate certificate: { id !r} ' )
self . _certs [ cert_id ] = cert
[文档]
class FSCertStore ( CertStore ):
"""File system certificate store."""
def __init__ ( self , path : str ) -> None :
super () . __init__ ()
if os . path . isdir ( path ):
path = os . path . join ( path , '*' )
for p in glob . glob ( path ):
with open ( p ) as f :
cert = Certificate ( f . read ())
if cert . has_expired ():
raise SecurityError (
f 'Expired certificate: { cert . get_id () !r} ' )
self . add_cert ( cert )
复制到剪贴板