Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions api/api-app/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,15 @@
<artifactId>micrometer-registry-prometheus</artifactId>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>com.github.vladimir-bukhtoyarov</groupId>
<artifactId>bucket4j-core</artifactId>
<version>8.0.1</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,39 @@
package co.nilin.opex.api.app.config

import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.kotlin.registerKotlinModule
import org.springframework.cache.CacheManager
import org.springframework.cache.annotation.EnableCaching
import org.springframework.cache.concurrent.ConcurrentMapCacheManager
import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration
import org.springframework.data.redis.connection.RedisConnectionFactory
import org.springframework.data.redis.core.RedisTemplate
import org.springframework.data.redis.serializer.GenericJackson2JsonRedisSerializer
import org.springframework.data.redis.serializer.StringRedisSerializer

@Configuration
@EnableCaching
class CacheConfig {

@Bean
fun redisTemplate(connectionFactory: RedisConnectionFactory, mapper: ObjectMapper): RedisTemplate<String, Any> {
val newMapper = mapper.copy().apply {
activateDefaultTyping(mapper.polymorphicTypeValidator, ObjectMapper.DefaultTyping.EVERYTHING)
findAndRegisterModules()
registerKotlinModule()
}
return RedisTemplate<String, Any>().apply {
setConnectionFactory(connectionFactory)
val ser = GenericJackson2JsonRedisSerializer(newMapper)
valueSerializer = ser
hashValueSerializer = ser
keySerializer = StringRedisSerializer()
hashKeySerializer = StringRedisSerializer()
afterPropertiesSet()
}
}

@Bean
fun apiKeyCacheManager(): CacheManager {
return ConcurrentMapCacheManager("apiKey")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package co.nilin.opex.api.app.config

import co.nilin.opex.api.app.service.RateLimitCoordinatorService
import co.nilin.opex.api.core.inout.RateLimitEndpoint
import co.nilin.opex.api.core.spi.RateLimitConfigService
import org.springframework.http.HttpStatus
import org.springframework.security.core.context.ReactiveSecurityContextHolder
import org.springframework.stereotype.Component
import org.springframework.web.server.ServerWebExchange
import org.springframework.web.server.WebFilter
import org.springframework.web.server.WebFilterChain
import org.springframework.web.util.pattern.PathPatternParser
import reactor.core.publisher.Mono

@Component
class RateLimitConfig(
private val rateLimitConfig: RateLimitConfigService,
private val coordinator: RateLimitCoordinatorService

) : WebFilter {
private val parser = PathPatternParser()

override fun filter(exchange: ServerWebExchange, chain: WebFilterChain): Mono<Void> {

val endpoint = rateLimitConfig.getEndpoints()
.asSequence()
.filter { it.enabled }
.filter { it.method.equals(exchange.request.method.name(), true) }
.sortedByDescending { it.priority }
.firstOrNull { endpoint ->
val pattern = parser.parse(endpoint.url)
pattern.matches(exchange.request.path)
}

if (endpoint == null) {
return chain.filter(exchange)
}

return applyRateLimitIfAuthenticated(exchange, chain, endpoint)
}


private fun applyRateLimitIfAuthenticated(
exchange: ServerWebExchange,
chain: WebFilterChain,
endpoint: RateLimitEndpoint
): Mono<Void> {

return ReactiveSecurityContextHolder.getContext()
.mapNotNull { it.authentication }
.filter { it.isAuthenticated }
.flatMap { auth ->
if (auth != null && !auth.name.isNullOrBlank())
applyRateLimit(auth.name, exchange, chain, endpoint)
else
chain.filter(exchange)
}

}


private fun applyRateLimit(
identity: String,
exchange: ServerWebExchange,
chain: WebFilterChain,
endpoint: RateLimitEndpoint
): Mono<Void> {

val group = rateLimitConfig.getGroup(endpoint.groupId)
?: return chain.filter(exchange)

val result = coordinator.check(
identity = identity,
groupId = endpoint.groupId,
maxRequests = group.requestCount,
windowSeconds = group.requestWindowSeconds,
apiPath = endpoint.url,
apiMethod = endpoint.method
)

return if (result.blocked) {
tooManyRequests(exchange, identity, endpoint.url, endpoint.method, result.retryAfterSeconds)
} else {
chain.filter(exchange)
}
}

//TODO should throw opex error
private fun tooManyRequests(
exchange: ServerWebExchange,
identity: String,
url: String,
method: String,
retryAfterSeconds: Int
): Mono<Void> {
exchange.response.statusCode = HttpStatus.TOO_MANY_REQUESTS
return exchange.response.writeWith(
Mono.just(
exchange.response.bufferFactory()
.wrap("Rate limit exceeded ($identity) -- $method:$url -- Retry-After, $retryAfterSeconds".toByteArray())
)
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package co.nilin.opex.api.app.config

import co.nilin.opex.api.core.spi.RateLimitConfigService
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import org.springframework.boot.context.event.ApplicationReadyEvent
import org.springframework.context.event.EventListener
import org.springframework.stereotype.Component

@Component
class RateLimitConfigLoader(
private val rateLimitConfig: RateLimitConfigService
) {
@EventListener(ApplicationReadyEvent::class)
fun preload() {
CoroutineScope(Dispatchers.Default).launch {
rateLimitConfig.loadConfig()
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package co.nilin.opex.api.app.controller

import co.nilin.opex.api.core.spi.RateLimitConfigService
import org.springframework.web.bind.annotation.PostMapping
import org.springframework.web.bind.annotation.RequestMapping
import org.springframework.web.bind.annotation.RestController

@RestController
@RequestMapping("/v1/rate-limit")
class RateLimitController(
private val rateLimitConfig: RateLimitConfigService,
) {
@PostMapping
suspend fun reloadRateLimits() {
rateLimitConfig.loadConfig()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package co.nilin.opex.api.app.data

data class BlockResult(
val blocked: Boolean,
val retryAfterSeconds: Int = 0
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package co.nilin.opex.api.app.data

data class RateLimitPenaltyState(
var violationCount: Int = 0,
var lastViolationAt: Long? = null,
var bannedUntil: Long? = null
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package co.nilin.opex.api.app.data

data class RateLimitState(
var violationCount: Int = 0,
var blockedUntil: Long? = null,
var lastViolationAt: Long? = null,
var graceRemaining: Int = 0
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package co.nilin.opex.api.app.service

import co.nilin.opex.api.app.data.BlockResult
import org.springframework.stereotype.Component

@Component
class RateLimitCoordinatorService(
private val rateLimiterService: RateLimiterService,
private val penaltyService: RateLimitPenaltyService
) {


fun check(
identity: String,
groupId: Long,
maxRequests: Int,
windowSeconds: Int,
apiPath: String,
apiMethod: String
): BlockResult {

val blocked = penaltyService.isBlocked(identity, apiPath, apiMethod)
if (blocked.blocked) {
return blocked
}

val allowed = rateLimiterService.checkRateLimit(
identity = identity,
maxRequests = maxRequests,
windowInSeconds = windowSeconds,
apiPath = apiPath,
apiMethod = apiMethod
)

return if (allowed) {
penaltyService.onAllowed(identity, groupId, apiPath, apiMethod)
BlockResult(blocked = false)
} else {
penaltyService.onLimit(identity, groupId, apiPath, apiMethod)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package co.nilin.opex.api.app.service

import co.nilin.opex.api.app.data.BlockResult
import co.nilin.opex.api.app.data.RateLimitPenaltyState
import co.nilin.opex.api.core.spi.RateLimitConfigService
import co.nilin.opex.api.ports.postgres.util.RedisCacheHelper
import co.nilin.opex.common.utils.DynamicInterval
import org.springframework.stereotype.Component
import java.time.Duration
import java.util.concurrent.TimeUnit
import kotlin.math.min

@Component
class RateLimitPenaltyService(private val config: RateLimitConfigService, private val redis: RedisCacheHelper) {

fun isBlocked(identity: String, apiPath: String, apiMethod: String): BlockResult {
val state = getPenaltyState(identity, apiPath, apiMethod) ?: return BlockResult(false)

val now = System.currentTimeMillis()
val bannedUntil = state.bannedUntil ?: return BlockResult(false)

return if (bannedUntil > now) {
BlockResult(
blocked = true,
retryAfterSeconds = ((bannedUntil - now) / 1000).toInt()
)
} else {
BlockResult(false)
}
}

fun onLimit(identity: String, groupId: Long, apiPath: String, apiMethod: String): BlockResult {
val now = System.currentTimeMillis()
val group = config.getGroup(groupId) ?: return BlockResult(false)
val penalties = config.getPenalties(groupId).sortedBy { it.blockStep }

val current = getPenaltyState(identity, apiPath, apiMethod)
val nextViolationCount = (current?.violationCount ?: 0) + 1

val level = min(nextViolationCount, penalties.size)
val penalty = penalties[level - 1]

val bannedUntil = now + Duration.ofSeconds(penalty.blockDurationSeconds.toLong()).toMillis()

val newState = RateLimitPenaltyState(
violationCount = nextViolationCount,
lastViolationAt = now,
bannedUntil = bannedUntil
)

val ttl = penalty.blockDurationSeconds + group.cooldownSeconds

savePenaltyState(identity, apiPath, apiMethod, newState, ttl)

return BlockResult(
blocked = true,
retryAfterSeconds = penalty.blockDurationSeconds
)
}

fun onAllowed(identity: String, groupId: Long, apiPath: String, apiMethod: String) {
val state = getPenaltyState(identity, apiPath, apiMethod) ?: return
val group = config.getGroup(groupId) ?: return
val now = System.currentTimeMillis()

val lastViolation = state.lastViolationAt ?: return
val cooldownMillis = Duration.ofSeconds(group.cooldownSeconds.toLong()).toMillis()

if (now - lastViolation >= cooldownMillis && state.violationCount > 0) {
val newState = state.copy(
violationCount = state.violationCount - 1
)
savePenaltyState(identity, apiPath, apiMethod, newState, group.cooldownSeconds)
}
}

private fun getPenaltyState(
identity: String,
apiPath: String,
apiMethod: String
): RateLimitPenaltyState? {
return redis.get(buildPenaltyStateKey(identity, apiPath, apiMethod))
}

private fun savePenaltyState(
identity: String,
apiPath: String,
apiMethod: String,
state: RateLimitPenaltyState,
ttlSeconds: Int
) {
redis.put(
buildPenaltyStateKey(identity, apiPath, apiMethod),
state,
DynamicInterval(ttlSeconds, TimeUnit.SECONDS)
)
}

private fun buildPenaltyStateKey(identity: String, apiPath: String, apiMethod: String): String {
val key = "$identity:$apiMethod:$apiPath"
return "rl:penalty:${key.hashCode()}"
}
}
Loading
Loading