"""AWS DynamoDB result store backend."""
from collections import namedtuple
from ipaddress import ip_address
from time import sleep, time
from typing import Any, Dict
from kombu.utils.url import _parse_url as parse_url
from celery.exceptions import ImproperlyConfigured
from celery.utils.log import get_logger
from .base import KeyValueStoreBackend
try:
    import boto3
    from botocore.exceptions import ClientError
except ImportError:
    boto3 = ClientError = None
__all__ = ('DynamoDBBackend',)
# Helper class that describes a DynamoDB attribute
DynamoDBAttribute = namedtuple('DynamoDBAttribute', ('name', 'data_type'))
logger = get_logger(__name__)
[文档]
class DynamoDBBackend(KeyValueStoreBackend):
    """AWS DynamoDB result backend.
    Raises:
        celery.exceptions.ImproperlyConfigured:
            if module :pypi:`boto3` is not available.
    """
    #: default DynamoDB table name (`default`)
    table_name = 'celery'
    #: Read Provisioned Throughput (`default`)
    read_capacity_units = 1
    #: Write Provisioned Throughput (`default`)
    write_capacity_units = 1
    #: AWS region (`default`)
    aws_region = None
    #: The endpoint URL that is passed to boto3 (local DynamoDB) (`default`)
    endpoint_url = None
    #: Item time-to-live in seconds (`default`)
    time_to_live_seconds = None
    # DynamoDB supports Time to Live as an auto-expiry mechanism.
    supports_autoexpire = True
    _key_field = DynamoDBAttribute(name='id', data_type='S')
    # Each record has either a value field or count field
    _value_field = DynamoDBAttribute(name='result', data_type='B')
    _count_filed = DynamoDBAttribute(name="chord_count", data_type='N')
    _timestamp_field = DynamoDBAttribute(name='timestamp', data_type='N')
    _ttl_field = DynamoDBAttribute(name='ttl', data_type='N')
    _available_fields = None
    implements_incr = True
    def __init__(self, url=None, table_name=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.url = url
        self.table_name = table_name or self.table_name
        if not boto3:
            raise ImproperlyConfigured(
                'You need to install the boto3 library to use the '
                'DynamoDB backend.')
        aws_credentials_given = False
        aws_access_key_id = None
        aws_secret_access_key = None
        if url is not None:
            scheme, region, port, username, password, table, query = \
                
parse_url(url)
            aws_access_key_id = username
            aws_secret_access_key = password
            access_key_given = aws_access_key_id is not None
            secret_key_given = aws_secret_access_key is not None
            if access_key_given != secret_key_given:
                raise ImproperlyConfigured(
                    'You need to specify both the Access Key ID '
                    'and Secret.')
            aws_credentials_given = access_key_given
            if region == 'localhost' or DynamoDBBackend._is_valid_ip(region):
                # We are using the downloadable, local version of DynamoDB
                self.endpoint_url = f'http://{region}:{port}'
                self.aws_region = 'us-east-1'
                logger.warning(
                    'Using local-only DynamoDB endpoint URL: {}'.format(
                        self.endpoint_url
                    )
                )
            else:
                self.aws_region = region
            # If endpoint_url is explicitly set use it instead
            _get = self.app.conf.get
            config_endpoint_url = _get('dynamodb_endpoint_url')
            if config_endpoint_url:
                self.endpoint_url = config_endpoint_url
            self.read_capacity_units = int(
                query.get(
                    'read',
                    self.read_capacity_units
                )
            )
            self.write_capacity_units = int(
                query.get(
                    'write',
                    self.write_capacity_units
                )
            )
            ttl = query.get('ttl_seconds', self.time_to_live_seconds)
            if ttl:
                try:
                    self.time_to_live_seconds = int(ttl)
                except ValueError as e:
                    logger.error(
                        f'TTL must be a number; got "{ttl}"',
                        exc_info=e
                    )
                    raise e
            self.table_name = table or self.table_name
        self._available_fields = (
            self._key_field,
            self._value_field,
            self._timestamp_field
        )
        self._client = None
        if aws_credentials_given:
            self._get_client(
                access_key_id=aws_access_key_id,
                secret_access_key=aws_secret_access_key
            )
    @staticmethod
    def _is_valid_ip(ip):
        try:
            ip_address(ip)
            return True
        except ValueError:
            return False
    def _get_client(self, access_key_id=None, secret_access_key=None):
        """Get client connection."""
        if self._client is None:
            client_parameters = {
                'region_name': self.aws_region
            }
            if access_key_id is not None:
                client_parameters.update({
                    'aws_access_key_id': access_key_id,
                    'aws_secret_access_key': secret_access_key
                })
            if self.endpoint_url is not None:
                client_parameters['endpoint_url'] = self.endpoint_url
            self._client = boto3.client(
                'dynamodb',
                **client_parameters
            )
            self._get_or_create_table()
            if self._has_ttl() is not None:
                self._validate_ttl_methods()
                self._set_table_ttl()
        return self._client
    def _get_table_schema(self):
        """Get the boto3 structure describing the DynamoDB table schema."""
        return {
            'AttributeDefinitions': [
                {
                    'AttributeName': self._key_field.name,
                    'AttributeType': self._key_field.data_type
                }
            ],
            'TableName': self.table_name,
            'KeySchema': [
                {
                    'AttributeName': self._key_field.name,
                    'KeyType': 'HASH'
                }
            ],
            'ProvisionedThroughput': {
                'ReadCapacityUnits': self.read_capacity_units,
                'WriteCapacityUnits': self.write_capacity_units
            }
        }
    def _get_or_create_table(self):
        """Create table if not exists, otherwise return the description."""
        table_schema = self._get_table_schema()
        try:
            return self._client.describe_table(TableName=self.table_name)
        except ClientError as e:
            error_code = e.response['Error'].get('Code', 'Unknown')
            if error_code == 'ResourceNotFoundException':
                table_description = self._client.create_table(**table_schema)
                logger.info(
                    'DynamoDB Table {} did not exist, creating.'.format(
                        self.table_name
                    )
                )
                # In case we created the table, wait until it becomes available.
                self._wait_for_table_status('ACTIVE')
                logger.info(
                    'DynamoDB Table {} is now available.'.format(
                        self.table_name
                    )
                )
                return table_description
            else:
                raise e
    def _has_ttl(self):
        """Return the desired Time to Live config.
        - True:  Enable TTL on the table; use expiry.
        - False: Disable TTL on the table; don't use expiry.
        - None:  Ignore TTL on the table; don't use expiry.
        """
        return None if self.time_to_live_seconds is None \
            
else self.time_to_live_seconds >= 0
    def _validate_ttl_methods(self):
        """Verify boto support for the DynamoDB Time to Live methods."""
        # Required TTL methods.
        required_methods = (
            'update_time_to_live',
            'describe_time_to_live',
        )
        # Find missing methods.
        missing_methods = []
        for method in list(required_methods):
            if not hasattr(self._client, method):
                missing_methods.append(method)
        if missing_methods:
            logger.error(
                (
                    'boto3 method(s) {methods} not found; ensure that '
                    'boto3>=1.9.178 and botocore>=1.12.178 are installed'
                ).format(
                    methods=','.join(missing_methods)
                )
            )
            raise AttributeError(
                'boto3 method(s) {methods} not found'.format(
                    methods=','.join(missing_methods)
                )
            )
    def _get_ttl_specification(self, ttl_attr_name):
        """Get the boto3 structure describing the DynamoDB TTL specification."""
        return {
            'TableName': self.table_name,
            'TimeToLiveSpecification': {
                'Enabled': self._has_ttl(),
                'AttributeName': ttl_attr_name
            }
        }
    def _get_table_ttl_description(self):
        # Get the current TTL description.
        try:
            description = self._client.describe_time_to_live(
                TableName=self.table_name
            )
        except ClientError as e:
            error_code = e.response['Error'].get('Code', 'Unknown')
            error_message = e.response['Error'].get('Message', 'Unknown')
            logger.error((
                'Error describing Time to Live on DynamoDB table {table}: '
                '{code}: {message}'
            ).format(
                table=self.table_name,
                code=error_code,
                message=error_message,
            ))
            raise e
        return description
    def _set_table_ttl(self):
        """Enable or disable Time to Live on the table."""
        # Get the table TTL description, and return early when possible.
        description = self._get_table_ttl_description()
        status = description['TimeToLiveDescription']['TimeToLiveStatus']
        if status in ('ENABLED', 'ENABLING'):
            cur_attr_name = \
                
description['TimeToLiveDescription']['AttributeName']
            if self._has_ttl():
                if cur_attr_name == self._ttl_field.name:
                    # We want TTL enabled, and it is currently enabled or being
                    # enabled, and on the correct attribute.
                    logger.debug((
                        'DynamoDB Time to Live is {situation} '
                        'on table {table}'
                    ).format(
                        situation='already enabled'
                        if status == 'ENABLED'
                        else 'currently being enabled',
                        table=self.table_name
                    ))
                    return description
        elif status in ('DISABLED', 'DISABLING'):
            if not self._has_ttl():
                # We want TTL disabled, and it is currently disabled or being
                # disabled.
                logger.debug((
                    'DynamoDB Time to Live is {situation} '
                    'on table {table}'
                ).format(
                    situation='already disabled'
                    if status == 'DISABLED'
                    else 'currently being disabled',
                    table=self.table_name
                ))
                return description
        # The state shouldn't ever have any value beyond the four handled
        # above, but to ease troubleshooting of potential future changes, emit
        # a log showing the unknown state.
        else:  # pragma: no cover
            logger.warning((
                'Unknown DynamoDB Time to Live status {status} '
                'on table {table}. Attempting to continue.'
            ).format(
                status=status,
                table=self.table_name
            ))
        # At this point, we have one of the following situations:
        #
        # We want TTL enabled,
        #
        # - and it's currently disabled: Try to enable.
        #
        # - and it's being disabled: Try to enable, but this is almost sure to
        #   raise ValidationException with message:
        #
        #     Time to live has been modified multiple times within a fixed
        #     interval
        #
        # - and it's currently enabling or being enabled, but on the wrong
        #   attribute: Try to enable, but this will raise ValidationException
        #   with message:
        #
        #     TimeToLive is active on a different AttributeName: current
        #     AttributeName is ttlx
        #
        # We want TTL disabled,
        #
        # - and it's currently enabled: Try to disable.
        #
        # - and it's being enabled: Try to disable, but this is almost sure to
        #   raise ValidationException with message:
        #
        #     Time to live has been modified multiple times within a fixed
        #     interval
        #
        attr_name = \
            
cur_attr_name if status == 'ENABLED' else self._ttl_field.name
        try:
            specification = self._client.update_time_to_live(
                **self._get_ttl_specification(
                    ttl_attr_name=attr_name
                )
            )
            logger.info(
                (
                    'DynamoDB table Time to Live updated: '
                    'table={table} enabled={enabled} attribute={attr}'
                ).format(
                    table=self.table_name,
                    enabled=self._has_ttl(),
                    attr=self._ttl_field.name
                )
            )
            return specification
        except ClientError as e:
            error_code = e.response['Error'].get('Code', 'Unknown')
            error_message = e.response['Error'].get('Message', 'Unknown')
            logger.error((
                'Error {action} Time to Live on DynamoDB table {table}: '
                '{code}: {message}'
            ).format(
                action='enabling' if self._has_ttl() else 'disabling',
                table=self.table_name,
                code=error_code,
                message=error_message,
            ))
            raise e
    def _wait_for_table_status(self, expected='ACTIVE'):
        """Poll for the expected table status."""
        achieved_state = False
        while not achieved_state:
            table_description = self.client.describe_table(
                TableName=self.table_name
            )
            logger.debug(
                'Waiting for DynamoDB table {} to become {}.'.format(
                    self.table_name,
                    expected
                )
            )
            current_status = table_description['Table']['TableStatus']
            achieved_state = current_status == expected
            sleep(1)
    def _prepare_get_request(self, key):
        """Construct the item retrieval request parameters."""
        return {
            'TableName': self.table_name,
            'Key': {
                self._key_field.name: {
                    self._key_field.data_type: key
                }
            }
        }
    def _prepare_put_request(self, key, value):
        """Construct the item creation request parameters."""
        timestamp = time()
        put_request = {
            'TableName': self.table_name,
            'Item': {
                self._key_field.name: {
                    self._key_field.data_type: key
                },
                self._value_field.name: {
                    self._value_field.data_type: value
                },
                self._timestamp_field.name: {
                    self._timestamp_field.data_type: str(timestamp)
                }
            }
        }
        if self._has_ttl():
            put_request['Item'].update({
                self._ttl_field.name: {
                    self._ttl_field.data_type:
                        str(int(timestamp + self.time_to_live_seconds))
                }
            })
        return put_request
    def _prepare_init_count_request(self, key: str) -> Dict[str, Any]:
        """Construct the counter initialization request parameters"""
        timestamp = time()
        return {
            'TableName': self.table_name,
            'Item': {
                self._key_field.name: {
                    self._key_field.data_type: key
                },
                self._count_filed.name: {
                    self._count_filed.data_type: "0"
                },
                self._timestamp_field.name: {
                    self._timestamp_field.data_type: str(timestamp)
                }
            }
        }
    def _prepare_inc_count_request(self, key: str) -> Dict[str, Any]:
        """Construct the counter increment request parameters"""
        return {
            'TableName': self.table_name,
            'Key': {
                self._key_field.name: {
                    self._key_field.data_type: key
                }
            },
            'UpdateExpression': f"set {self._count_filed.name} = {self._count_filed.name} + :num",
            "ExpressionAttributeValues": {
                ":num": {"N": "1"},
            },
            "ReturnValues": "UPDATED_NEW",
        }
    def _item_to_dict(self, raw_response):
        """Convert get_item() response to field-value pairs."""
        if 'Item' not in raw_response:
            return {}
        return {
            field.name: raw_response['Item'][field.name][field.data_type]
            for field in self._available_fields
        }
    @property
    def client(self):
        return self._get_client()
[文档]
    def get(self, key):
        key = str(key)
        request_parameters = self._prepare_get_request(key)
        item_response = self.client.get_item(**request_parameters)
        item = self._item_to_dict(item_response)
        return item.get(self._value_field.name) 
[文档]
    def set(self, key, value):
        key = str(key)
        request_parameters = self._prepare_put_request(key, value)
        self.client.put_item(**request_parameters) 
[文档]
    def mget(self, keys):
        return [self.get(key) for key in keys] 
[文档]
    def delete(self, key):
        key = str(key)
        request_parameters = self._prepare_get_request(key)
        self.client.delete_item(**request_parameters) 
[文档]
    def incr(self, key: bytes) -> int:
        """Atomically increase the chord_count and return the new count"""
        key = str(key)
        request_parameters = self._prepare_inc_count_request(key)
        item_response = self.client.update_item(**request_parameters)
        new_count: str = item_response["Attributes"][self._count_filed.name][self._count_filed.data_type]
        return int(new_count) 
    def _apply_chord_incr(self, header_result_args, body, **kwargs):
        chord_key = self.get_key_for_chord(header_result_args[0])
        init_count_request = self._prepare_init_count_request(str(chord_key))
        self.client.put_item(**init_count_request)
        return super()._apply_chord_incr(
            header_result_args, body, **kwargs)