节流概念

节流的相关源码在GenericAPIView中的APIView中的dispatch中的initial中的self`.check_throttles(request)`

check_throttles(request)

    def check_throttles(self, request):
        """
        Check if request should be throttled.
        Raises an appropriate exception if the request is throttled.
        """
        throttle_durations = []
        for throttle in self.get_throttles():
            if not throttle.allow_request(request, self):
                throttle_durations.append(throttle.wait())  

        if throttle_durations:
            # Filter out `None` values which may happen in case of config / rate
            # changes, see #1438
            durations = [
                duration for duration in throttle_durations
                if duration is not None
            ]

            duration = max(durations, default=None)
            self.throttled(request, duration)
  • 第7行:如果不想允许请求就将其放入列表中throttle_durations
  • 11-20行:检测到throttle_durations列表不为空,说明有请求被节流,所以抛出异常

节流功能分析

根据上面的源码分析,要实现节流功能,就实现get_throttles()中的方法。

get_throttles()

    def get_throttles(self):
        """
        Instantiates and returns the list of throttles that this view uses.
        """
        return [throttle() for throttle in self.throttle_classes]

观察throttle_classes

throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES

实际上默认节流就是api_settings中的default中的节流配置
观察DEFAULT_THROTTLE_CLASSES:

'DEFAULT_THROTTLE_CLASSES': [],

可以看到源码中给的是空的列表,也就是说没有节流。

Throttle.py分析

BaseThrottle

class BaseThrottle:
    """
    Rate throttling of requests.
    """

    def allow_request(self, request, view):
        """
        Return `True` if the request should be allowed, `False` otherwise.
        """
        raise NotImplementedError('.allow_request() must be overridden')

    def get_ident(self, request):  # 获取唯一标识
        """
        Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR
        if present and number of proxies is > 0. If not use all of
        HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR.
        """
        xff = request.META.get('HTTP_X_FORWARDED_FOR')  # 获取原始ip
        remote_addr = request.META.get('REMOTE_ADDR')
        num_proxies = api_settings.NUM_PROXIES

        if num_proxies is not None:
            if num_proxies == 0 or xff is None:
                return remote_addr
            addrs = xff.split(',')
            client_addr = addrs[-min(num_proxies, len(addrs))]
            return client_addr.strip()

        return ''.join(xff.split()) if xff else remote_addr # 如果有原始ip就返回给原始ip,否则返回给代理ip

    def wait(self):
        """
        Optionally, return a recommended number of seconds to wait before
        the next request.
        """
        return None
  • 6-10行 allow_request()使用allow_request必须要重写实现
  • 12行       get_ident()获取唯一标识

    • HTTP_X_FORWARDED_FOR

      • 获取原始ip
    • REMOTE_ADDR

      • 通过普通的代理发送的请求,获取REMOTE_ADDR获取到的是代理IP
    • 补充:代理

      • 普通代理
      • 高匿代理

        • 越是加密,导致效率越低,请求速度越慢

    SimpleRateThrottle

    找到SimpleRateThrottle所在位置:

可以看到SimpleRateThrottle继承自BaseThrottle

源码:

class SimpleRateThrottle(BaseThrottle):
    """
    A simple cache implementation, that only requires `.get_cache_key()`
    to be overridden.

    The rate (requests / seconds) is set by a `rate` attribute on the View
    class.  The attribute is a string of the form 'number_of_requests/period'.

    Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')

    Previous request information used for throttling is stored in the cache.
    """
    cache = default_cache
    timer = time.time
    cache_format = 'throttle_%(scope)s_%(ident)s'
    scope = None
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

    def __init__(self):
        if not getattr(self, 'rate', None):
            self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)

    def get_cache_key(self, request, view):
        """
        Should return a unique cache-key which can be used for throttling.
        Must be overridden.

        May return `None` if the request should not be throttled.
        """
        raise NotImplementedError('.get_cache_key() must be overridden')

    def get_rate(self):
        """
        Determine the string representation of the allowed request rate.
        """
        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)

        try:
            return self.THROTTLE_RATES[self.scope]
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)

    def parse_rate(self, rate):
        """
        Given the request rate string, return a two tuple of:
        <allowed number of requests>, <period of time in seconds>
        """
        if rate is None:
            return (None, None)
        num, period = rate.split('/')
        num_requests = int(num)
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        return (num_requests, duration)

    def allow_request(self, request, view):
        """
        Implement the check to see if the request should be throttled.

        On success calls `throttle_success`.
        On failure calls `throttle_failure`.
        """
        if self.rate is None:  # 如果频率为空,允许访问
            return True

        self.key = self.get_cache_key(request, view)  # 获取缓存key
        if self.key is None:  # 如果缓存key是空,也是允许的
            return True

        self.history = self.cache.get(self.key, [])
        self.now = self.timer()  # 获取时间

        # Drop any requests from the history which have now passed the
        # throttle duration
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()  # 清理过期访问次数记录
        if len(self.history) >= self.num_requests:
            return self.throttle_failure()  # 如果超过请求数量就返回失败
        return self.throttle_success()

    def throttle_success(self):
        """
        Inserts the current request's timestamp along with the key
        into the cache.
        """
        self.history.insert(0, self.now)
        self.cache.set(self.key, self.history, self.duration)
        return True

    def throttle_failure(self):
        """
        Called when a request to the API has failed due to throttling.
        """
        return False

    def wait(self):
        """
        Returns the recommended next request time in seconds.
        """
        if self.history:
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration

        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            return None

        return remaining_duration / float(available_requests)
  • get_cache_key

    • 获取缓存key,需要重写实现
  • get_rate

    • 获取访问频率
  • parse_rate

    • 转换频率
  • allow_request

    • 这里重写了BaseThrottle中的allow_request方法,详细看上方代码注释
    • 类似于之前笔记中写过的频率限制中间件

AnonRateThrottle

在源码中查找AnonRateThrottle:

class AnonRateThrottle(SimpleRateThrottle):
    """
    Limits the rate of API calls that may be made by a anonymous users.

    The IP address of the request will be used as the unique cache key.
    """
    scope = 'anon'

    def get_cache_key(self, request, view):
        if request.user.is_authenticated:
            return None  # Only throttle unauthenticated requests.

        return self.cache_format % {
            'scope': self.scope,
            'ident': self.get_ident(request)
        }

可以看到AnonRateThrottle继承自SimpleRateThrottle


UserRateThrottle

class UserRateThrottle(SimpleRateThrottle):
    """
    Limits the rate of API calls that may be made by a given user.

    The user id will be used as a unique cache key if the user is
    authenticated.  For anonymous requests, the IP address of the request will
    be used.
    """
    scope = 'user'

    def get_cache_key(self, request, view):
        if request.user.is_authenticated: # 如果用户是经过认证的
            ident = request.user.pk  # 就获得用户的主键
        else:
            ident = self.get_ident(request)  # 如果没有主键就将get_ident作为唯一标识

        return self.cache_format % {  # 格式化 cache_format = 'throttle_%(scope)s_%(ident)s'
            'scope': self.scope,
            'ident': ident
        }  # 相当于返回key

Setting the throttling policy (设定限流策略)

  • 实现全局节流,频率限制

    • 实现全局在settings中配置

创建throttles.py

创建throttle子类

from rest_framework.throttling import SimpleRateThrottle

from App.models import UserModel

class UserThrottle(SimpleRateThrottle):
    scope = 'user'
    def get_cache_key(self, request, view):  # 重写方法
        if isinstance(request.user, UserModel):  # 如果是模型的实例
            ident = request.auth  # 使用auth属性作为唯一标识
        else:
            ident = self.get_ident(request)

        return self.cache_format % {
            'scope': self.scope,
            'ident': ident
        }

通过使用 DEFAULT_THROTTLE_CLASSES 和 DEFAULT_THROTTLE_RATES 默认限流策略将被全局设定,例如:

REST_FRAMEWORK={
    'DEFAULT_THROTTLE_CLASSES': (
        'App.throttle.UserRateThrottle',  # 需要节流的类
    ),
    'DEFAULT_THROTTLE_RATES':{  # 频率限制
        'user': "5/m"  # 限制用户一分钟5次,user对应scope的值
    }

测试



连续访问,成功被限制。

如果想要针对某个请求方法,某个视图函数都可以定制

某个view视图类单独节流

指针对单个的视图类时可以直接在views中的class中添加throttle_class即可。

节流器总结

  • BaseThrottle

    • allow_request

      • 是否允许的请求的核心
    • get_ident

      • 获取客户端唯一标识
    • wait
  • SimpleRateThrottle

    • get_cache_key

      • 获取缓存标识
    • get_rate

      • 获取频率
    • parse_rate

      • 转换频率

        • num/duration

          • duration

            • s 秒
            • m 分钟
            • h 小时
            • d 天
    • allow_request

      • 是否允许请求
      • 重写的方法
    • throttle_success

      • 允许请求,进行请求记录
    • throttle_failure

      • 不允许请求
    • wait

      • 还有多少时间之后允许访问
  • AnonRateThrottle

    • get_cache_key

      • 获取缓存key的原则
  • UserRateThrottle

    • 结构同上
  • ScopedRateThrottle

    • 结构同上
    • 多写了从属性中获取频率,这样可以在views.py中对应的class中写throttle_scope也就是限制频率,例如10/m
最后修改:2024 年 03 月 13 日
如果觉得我的文章对你有用,请随意赞赏