Spring-Security-Multi-Tenant-Test

本文最后更新于:几秒前

Spring-Security-Multi-Tenant-Test

需求: 在多租户系统中自定义 Spring Security 认证, 登录成功后返回用户信息; 在测试中, 可以通过自定义的 MockUser 注解正常获取当前用户的租户信息。

实现思路

  1. 自定义认证过滤器 TenantCaptchaAuthenticationFilter
  2. JWT Util, 获取Token
  3. 添加 JWT 过滤器, 从请求中解析 Token, 并将认证信息设置到 SecurityContext 中
  4. 自定义 MockUser 注解 (WithMockTenantUser)
  5. 为该注解实现 WithSecurityContextFactory (WithMockTenantUserSecurityContextFactory)

SpringSecurityConfig

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import com.fasterxml.jackson.databind.ObjectMapper
import com.zhengchalei.service.sys.config.GlobalException
import org.springframework.beans.factory.annotation.Value
import org.springframework.context.annotation.Bean
import org.springframework.security.authentication.AuthenticationManager
import org.springframework.security.config.annotation.authentication.configuration.AuthenticationConfiguration
import org.springframework.security.config.annotation.method.configuration.EnableMethodSecurity
import org.springframework.security.config.annotation.web.builders.HttpSecurity
import org.springframework.security.web.SecurityFilterChain
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter
import org.springframework.stereotype.Service


/**
 * Security configuration for the multi-tenant application.
 *
 * Registers the JWT authorization filter and the tenant/captcha login filter in front of
 * [UsernamePasswordAuthenticationFilter], opens the auth endpoints (plus the OpenAPI documents
 * on the `dev` / `test` profiles) and requires authentication for every other request.
 *
 * NOTE(review): the class is annotated with `@Service`; `@Configuration` is the conventional
 * stereotype for a class that only declares `@Bean` methods — confirm before changing.
 */
@Service
@EnableMethodSecurity(prePostEnabled = true, securedEnabled = true, jsr250Enabled = true)
class SpringSecurityConfig(
    // The trailing ':' supplies an empty default so startup does not fail when no profile is active.
    @Value("\${spring.profiles.active:}")
    private val profile: String,
    private val objectMapper: ObjectMapper,
    private val tenantCaptchaAuthenticationProvider: TenantCaptchaAuthenticationProvider
) {

    /**
     * Builds the application's [SecurityFilterChain].
     *
     * @param http the security builder being configured
     * @param jwtProvider used by the JWT filter to validate tokens and build the authentication
     * @param authenticationManager manager handed to the tenant/captcha login filter
     */
    @Bean
    fun securityFilterChain(
        http: HttpSecurity,
        jwtProvider: JwtProvider,
        authenticationManager: AuthenticationManager
    ): SecurityFilterChain {
        val tenantCaptchaAuthenticationFilter = TenantCaptchaAuthenticationFilter()
        tenantCaptchaAuthenticationFilter.setAuthenticationManager(authenticationManager)
        return http
            .addFilterBefore(
                JwtConfigurer.JwtAuthorizationFilter(jwtProvider),
                UsernamePasswordAuthenticationFilter::class.java
            )
            .addFilterBefore(tenantCaptchaAuthenticationFilter, UsernamePasswordAuthenticationFilter::class.java)
            .authenticationProvider(tenantCaptchaAuthenticationProvider)
            .authorizeHttpRequests { authorize ->
                // favicon.ico
                authorize.requestMatchers("/favicon.ico").permitAll()

                authorize.requestMatchers("/api/auth/login").permitAll()
                authorize.requestMatchers("/api/auth/register").permitAll()
                authorize.requestMatchers("/api/auth/captcha").permitAll()
                // The API documentation is only public on development / test profiles.
                if (isApiDocsProfile()) {
                    authorize.requestMatchers("/openapi.html").permitAll()
                    authorize.requestMatchers("/openapi.yml").permitAll()
                }
                authorize.anyRequest().authenticated()
            }
            .exceptionHandling {
                it.authenticationEntryPoint { _, response, authException ->
                    writeJsonError(response, 401, authException.message ?: "未登录")
                }
                it.accessDeniedHandler { _, response, accessDeniedException ->
                    writeJsonError(response, 403, accessDeniedException.message ?: "无权限")
                }
            }
            .sessionManagement {
                it.disable()
            }
            .csrf {
                it.disable()
            }
            .build()
    }

    /** Exposes the [AuthenticationManager] built by Spring's [AuthenticationConfiguration] as a bean. */
    @Bean
    fun authenticationManager(configuration: AuthenticationConfiguration): AuthenticationManager {
        return configuration.authenticationManager
    }

    /**
     * Returns true when `dev` or `test` is among the active profiles.
     * The property may hold a comma-separated list, so each entry is checked individually.
     */
    private fun isApiDocsProfile(): Boolean =
        profile.split(',').map { it.trim() }.any { it == "dev" || it == "test" }

    /**
     * Writes [message] wrapped in a `GlobalException.Error` as the JSON response body.
     *
     * `sendError(status, text)` is deliberately not used: it treats the text as the container's
     * error message rather than as the response body, so clients would not receive the JSON.
     */
    private fun writeJsonError(response: jakarta.servlet.http.HttpServletResponse, status: Int, message: String) {
        response.status = status
        response.contentType = "application/json;charset=UTF-8"
        response.writer.write(objectMapper.writeValueAsString(GlobalException.Error(message)))
    }
}

TenantCaptchaAuthenticationFilter

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletResponse
import org.springframework.security.core.Authentication
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter

/**
 * Login filter that builds a [TenantCaptchaAuthenticationToken] from the request parameters
 * `username`, `password`, `tenant` and `captcha` and passes it to the authentication manager.
 */
class TenantCaptchaAuthenticationFilter : UsernamePasswordAuthenticationFilter() {

    /**
     * Reads the login parameters and delegates authentication to the configured manager.
     *
     * @return the result of `authenticationManager.authenticate` for the built token
     */
    override fun attemptAuthentication(request: HttpServletRequest, response: HttpServletResponse): Authentication {
        // getParameter returns null for an absent parameter. Falling back to "" lets the
        // provider reject the login instead of crashing on the token's non-null parameters.
        val username = request.getParameter("username").orEmpty()
        val password = request.getParameter("password").orEmpty()
        val tenantId = request.getParameter("tenant").orEmpty()
        val captcha = request.getParameter("captcha").orEmpty()

        val authRequest = TenantCaptchaAuthenticationToken(
            username, password, tenantId, captcha
        )
        setDetails(request, authRequest)
        return authenticationManager.authenticate(authRequest)
    }
}

TenantCaptchaAuthenticationProvider

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68

import com.zhengchalei.service.sys.config.*
import com.zhengchalei.service.sys.domain.SysTenant
import com.zhengchalei.service.sys.domain.code
import com.zhengchalei.service.sys.domain.id
import com.zhengchalei.service.sys.repository.SysTenantRepository
import com.zhengchalei.service.sys.repository.SysUserRepository
import org.babyfish.jimmer.sql.kt.ast.expression.eq
import org.springframework.security.authentication.AuthenticationProvider
import org.springframework.security.core.Authentication
import org.springframework.security.core.GrantedAuthority
import org.springframework.security.core.authority.SimpleGrantedAuthority
import org.springframework.security.core.userdetails.User
import org.springframework.security.core.userdetails.UserDetails
import org.springframework.security.crypto.password.PasswordEncoder
import org.springframework.stereotype.Component

/**
 * Authentication provider for [TenantCaptchaAuthenticationToken]: checks the tenant code,
 * the captcha and the user's password, then returns a token carrying the user's authorities.
 */
@Component
class TenantCaptchaAuthenticationProvider(
    val sysTenantRepository: SysTenantRepository,
    val sysUserRepository: SysUserRepository,
    val passwordEncoder: PasswordEncoder,
) : AuthenticationProvider {

    /**
     * Authenticates a tenant/captcha login request.
     *
     * Each failed check raises its own exception, so a bad tenant or captcha is never
     * reported as a password error.
     *
     * @throws TenantNotFoundException when no tenant has the submitted code
     * @throws CaptchaErrorException when the captcha is rejected
     * @throws UserPasswordErrorException when the credentials are absent or do not match
     */
    override fun authenticate(authentication: Authentication): Authentication {
        val token = authentication as TenantCaptchaAuthenticationToken
        val username = token.name
        // Credentials may be null (e.g. already erased); treat that as a wrong password
        // rather than letting an unchecked cast fail.
        val password = token.credentials as? String ?: throw UserPasswordErrorException()
        val tenant = token.tenant
        val captcha = token.captcha

        if (!isValidTenant(tenant)) throw TenantNotFoundException()
        if (!isValidCaptcha(captcha)) throw CaptchaErrorException()

        val userDetails = loadUserByUsername(username, password, tenant)
        return TenantCaptchaAuthenticationToken(username, password, tenant, captcha, userDetails.authorities)
    }

    override fun supports(authentication: Class<*>): Boolean {
        return TenantCaptchaAuthenticationToken::class.java.isAssignableFrom(authentication)
    }

    /** Returns true when a tenant with the given code exists. */
    private fun isValidTenant(tenant: String): Boolean =
        this.sysTenantRepository.sql.createQuery(SysTenant::class) {
            where(table.code eq tenant)
            select(table.id)
        }.fetchOneOrNull() != null

    /**
     * Placeholder captcha check.
     *
     * TODO implement real captcha validation; every captcha is currently accepted.
     */
    @Suppress("UNUSED_PARAMETER")
    private fun isValidCaptcha(captcha: String): Boolean {
        return true
    }

    /**
     * Loads the user of [tenant] named [username], verifies [password] and collects authorities.
     *
     * Permission codes are used as-is; role codes get the `ROLE_` prefix.
     *
     * @throws UserNotFoundException when the repository finds no such user
     * @throws UserDisabledException when the user's status flag is false
     * @throws UserPasswordErrorException when the password does not match
     */
    fun loadUserByUsername(username: String, password: String, tenant: String): UserDetails {
        val user = sysUserRepository.findByUsernameAndTenant(username, tenant) ?: throw UserNotFoundException()
        if (!user.status) throw UserDisabledException()
        if (!passwordEncoder.matches(password, user.password)) throw UserPasswordErrorException()

        val roles = user.roles
        val permissionAuthorities = roles.flatMap { it.permissions }.map { SimpleGrantedAuthority(it.code) }
        val roleAuthorities = roles.map { SimpleGrantedAuthority("ROLE_${it.code}") }
        val authorityList: List<GrantedAuthority> = permissionAuthorities + roleAuthorities
        return User(username, user.password, user.status, true, true, true, authorityList)
    }
}

TenantCaptchaAuthenticationToken

1
2
3
4
5
6
7
8
9
10
11
12
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken
import org.springframework.security.core.GrantedAuthority

/**
 * Username/password token that additionally carries the tenant code and the captcha
 * submitted with the login request.
 *
 * NOTE(review): every instance is built through the three-argument super constructor, which
 * in Spring Security appears to mark the token as authenticated — confirm whether the
 * not-yet-verified login request token should go through it as well.
 *
 * @param username principal passed to the super constructor
 * @param password credentials passed to the super constructor
 * @property tenant tenant code
 * @property captcha captcha text
 * @param authorities authorities passed to the super constructor; empty by default
 */
class TenantCaptchaAuthenticationToken(
    username: Any,
    password: Any,
    val tenant: String,
    val captcha: String,
    authorities: Collection<GrantedAuthority> = emptyList(),
) : UsernamePasswordAuthenticationToken(username, password, authorities)

JwtConfigurer / JwtAuthorizationFilter

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61

import jakarta.servlet.FilterChain
import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletResponse
import org.springframework.security.config.annotation.SecurityConfigurerAdapter
import org.springframework.security.config.annotation.web.builders.HttpSecurity
import org.springframework.security.core.Authentication
import org.springframework.security.core.context.SecurityContextHolder
import org.springframework.security.web.DefaultSecurityFilterChain
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter
import org.springframework.web.filter.OncePerRequestFilter

/**
 * Configurer that places a [JwtAuthorizationFilter] in front of
 * [UsernamePasswordAuthenticationFilter].
 */
class JwtConfigurer(private val jwtProvider: JwtProvider) :
    SecurityConfigurerAdapter<DefaultSecurityFilterChain, HttpSecurity>() {

    override fun configure(http: HttpSecurity) {
        // Build the JWT filter with the token provider and register it ahead of the
        // username/password filter.
        http.addFilterBefore(
            JwtAuthorizationFilter(jwtProvider),
            UsernamePasswordAuthenticationFilter::class.java
        )
    }

    /**
     * Per-request filter that reads a JWT from the request and, when the provider accepts it,
     * stores the resulting authentication in the security context.
     */
    class JwtAuthorizationFilter(private val jwtProvider: JwtProvider) : OncePerRequestFilter() {

        override fun doFilterInternal(
            request: HttpServletRequest,
            response: HttpServletResponse,
            filterChain: FilterChain
        ) {
            // Only a non-blank token that passes validation populates the security context.
            resolveToken(request)
                ?.takeIf { it.isNotBlank() && jwtProvider.validateToken(it) }
                ?.let { validToken ->
                    val authentication: Authentication = jwtProvider.getAuthentication(validToken)
                    SecurityContextHolder.getContext().authentication = authentication
                }
            // Always continue with the rest of the chain.
            filterChain.doFilter(request, response)
        }

        /**
         * Returns the token from the `Authorization: Bearer …` header when present, otherwise
         * the non-blank `access_token` request parameter, otherwise null.
         */
        private fun resolveToken(request: HttpServletRequest): String? {
            val header: String? = request.getHeader(AUTHORIZATION)
            if (header != null && header.isNotBlank() && header.startsWith(BEARER_PREFIX)) {
                // Drop the "Bearer " prefix and keep the rest.
                return header.substring(BEARER_PREFIX.length)
            }
            return request.getParameter(ACCESS_TOKEN)?.takeUnless { it.isBlank() }
        }

        private companion object {
            const val AUTHORIZATION = "Authorization"
            const val ACCESS_TOKEN = "access_token"
            const val BEARER_PREFIX = "Bearer "
        }
    }
}

JwtProvider

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158

import com.nimbusds.jose.JOSEException
import com.nimbusds.jose.JWSAlgorithm
import com.nimbusds.jose.JWSHeader
import com.nimbusds.jose.crypto.MACSigner
import com.nimbusds.jose.crypto.MACVerifier
import com.nimbusds.jwt.JWTClaimsSet
import com.nimbusds.jwt.SignedJWT
import com.zhengchalei.service.sys.repository.SysDictRepository
import org.slf4j.LoggerFactory
import org.springframework.boot.CommandLineRunner
import org.springframework.security.core.Authentication
import org.springframework.security.core.authority.SimpleGrantedAuthority
import org.springframework.security.core.userdetails.User
import org.springframework.stereotype.Component
import java.util.*


/**
 * Creates, validates and decodes HS256-signed JWTs.
 *
 * The signing secret and the token lifetime are loaded from the `jwt` dictionary
 * (`secret` and `expired` items) once the application has started, see [run].
 */
@Component
class JwtProvider(private val dictRepository: SysDictRepository) : CommandLineRunner {

    private val logger = LoggerFactory.getLogger(JwtProvider::class.java)

    // Signing key. Starts as a random key of the minimum length so that signing always works
    // and no publicly known default can be used to forge tokens; replaced in run() when a
    // usable secret is configured.
    private var secret: ByteArray = randomSecret()

    // Token lifetime in MILLISECONDS (it is added to System.currentTimeMillis()).
    private var expiration: Long = DEFAULT_EXPIRATION_SECONDS * 1000

    /**
     * Loads `secret` and `expired` (seconds) from the `jwt` dictionary.
     * Missing or unusable values are logged and the current defaults are kept.
     */
    override fun run(vararg args: String) {
        val dictItems = this.dictRepository.findByCode("jwt")?.dictItems ?: emptyList()

        // secret
        val secretBytes = dictItems.find { it.code == "secret" }?.data?.toByteArray(Charsets.UTF_8)
        when {
            secretBytes == null ->
                logger.error("JWT 参数初始化失败, 缺少 'secret', 正在使用随机生成的临时密钥 (重启后已签发的 Token 将失效): 危险")

            // The secret must be at least 32 bytes to be usable for HS256 signing.
            secretBytes.size < MIN_SECRET_BYTES ->
                logger.error("JWT 参数错误, 'secret' 不能少于 32 个字节, 正在使用随机生成的临时密钥 (重启后已签发的 Token 将失效): 危险")

            else -> this.secret = secretBytes
        }

        // expired (seconds)
        val expiredItem = dictItems.find { it.code == "expired" }
        // toLongOrNull + takeIf: non-numeric or non-positive values become null instead of
        // throwing and aborting startup.
        val expiredSeconds = expiredItem?.data?.trim()?.toLongOrNull()?.takeIf { it > 0 }
        when {
            expiredItem == null ->
                logger.warn("JWT 参数初始化失败, 缺少 'expired', 正在使用默认值: 3600 秒")

            expiredSeconds == null ->
                logger.warn("JWT 参数初始化错误, 'expired' 必须是大于 0 的整数 (秒), 正在使用默认值: 3600 秒")

            else -> this.expiration = expiredSeconds * 1000 // seconds -> milliseconds
        }

        logger.info("jwt 参数初始化完毕")
    }

    /**
     * Creates a signed JWT for an authenticated token.
     *
     * Claims: `sub` / `username` (the authentication name), `roles` (authorities starting with
     * `ROLE_`, comma-joined), `permissions` (all other authorities, comma-joined), `tenant`
     * and the expiration time.
     *
     * @return the serialized, signed JWT
     * @throws JOSEException if signing fails
     */
    fun createToken(authentication: TenantCaptchaAuthenticationToken): String {
        // Header
        val header = JWSHeader(JWSAlgorithm.HS256)

        val (roles, permissions) = authentication.authorities
            .map { it.authority }
            .partition { it.startsWith("ROLE_") }

        // Payload
        val claimsSet = JWTClaimsSet.Builder()
            .subject(authentication.name)
            .claim("username", authentication.name)
            .claim("roles", roles.joinToString(","))
            .claim("permissions", permissions.joinToString(","))
            .claim("tenant", authentication.tenant)
            .expirationTime(Date(System.currentTimeMillis() + expiration))
            .build()

        val signedJWT = SignedJWT(header, claimsSet)
        signedJWT.sign(MACSigner(secret))
        return signedJWT.serialize()
    }

    /**
     * Parses a JWT string.
     *
     * @param token serialized JWT
     * @return the parsed [SignedJWT]
     */
    fun parseToken(token: String): SignedJWT {
        return SignedJWT.parse(token)
    }

    /**
     * Checks whether a JWT is valid.
     *
     * @param token serialized JWT
     * @return true only when the token parses, its signature verifies with the current secret
     * and it carries an expiration time that has not passed; false otherwise
     */
    fun validateToken(token: String): Boolean {
        return try {
            val jwt = parseToken(token)
            if (!jwt.verify(MACVerifier(secret))) return false
            // A token without an expiration time is rejected.
            val expirationTime = jwt.jwtClaimsSet?.expirationTime ?: return false
            !expirationTime.before(Date())
        } catch (e: java.text.ParseException) {
            // Malformed token: invalid, not a server error.
            logger.debug("JWT 解析失败: {}", e.message)
            false
        } catch (e: JOSEException) {
            logger.debug("JWT 校验失败: {}", e.message)
            false
        }
    }

    /**
     * Builds the [Authentication] described by a JWT's claims.
     *
     * The signature is not checked here; call [validateToken] first.
     *
     * @throws RuntimeException when the claims set or the `tenant` claim is missing
     */
    fun getAuthentication(token: String): Authentication {
        val jwt: SignedJWT = parseToken(token)
        val jwtClaimsSet = jwt.jwtClaimsSet ?: throw RuntimeException("JWT ClaimsSet is null")
        val tenant = jwtClaimsSet.getStringClaim("tenant")
            ?: throw RuntimeException("JWT is missing the 'tenant' claim")

        // An empty claim splits into [""]; blank entries are dropped so no authority is built
        // from empty text.
        val authorities = listOf("permissions", "roles")
            .flatMap { jwtClaimsSet.getStringClaim(it).orEmpty().split(",") }
            .filter { it.isNotBlank() }
            .map { SimpleGrantedAuthority(it) }

        val principal = User(jwtClaimsSet.subject, "", authorities)

        return TenantCaptchaAuthenticationToken(
            username = principal,
            password = "",
            tenant = tenant,
            captcha = "",
            authorities = authorities
        )
    }

    private companion object {
        /** Minimum secret length in bytes accepted for HS256 signing. */
        const val MIN_SECRET_BYTES = 32

        /** Default token lifetime: one hour. */
        const val DEFAULT_EXPIRATION_SECONDS = 3600L

        /** Generates a random signing key of the minimum accepted length. */
        fun randomSecret(): ByteArray =
            ByteArray(MIN_SECRET_BYTES).also { java.security.SecureRandom().nextBytes(it) }
    }
}

Test

WithMockTenantUser

1
2
3
4
5
6
7
8
9
10
11
12
13
import org.springframework.security.test.context.support.WithSecurityContext

/**
 * Test annotation that installs a mock tenant user into the security context via
 * [WithMockTenantUserSecurityContextFactory].
 *
 * @property username login name of the mock user
 * @property name display name of the mock user
 * @property password password of the mock user
 * @property roles role names of the mock user
 * @property authorities extra authorities of the mock user
 * @property tenant tenant code of the mock user
 */
@Retention(AnnotationRetention.RUNTIME)
@WithSecurityContext(factory = WithMockTenantUserSecurityContextFactory::class)
annotation class WithMockTenantUser(
    val username: String = "admin",
    val name: String = "超级管理员",
    val password: String = "admin",
    val roles: Array<String> = ["ADMIN"],
    val authorities: Array<String> = [],
    val tenant: String = "default",
)

WithMockTenantUserSecurityContextFactory
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import com.zhengchalei.service.sys.config.security.TenantCaptchaAuthenticationToken
import org.springframework.security.core.Authentication
import org.springframework.security.core.GrantedAuthority
import org.springframework.security.core.authority.SimpleGrantedAuthority
import org.springframework.security.core.context.SecurityContext
import org.springframework.security.core.context.SecurityContextHolder
import org.springframework.security.core.userdetails.User
import org.springframework.security.core.userdetails.UserDetails
import org.springframework.security.test.context.support.WithSecurityContextFactory
import java.util.*

/**
 * Builds the security context for tests annotated with [WithMockTenantUser].
 */
class WithMockTenantUserSecurityContextFactory : WithSecurityContextFactory<WithMockTenantUser> {

    /**
     * Creates a context whose authentication is a [TenantCaptchaAuthenticationToken] with the
     * annotation's authorities followed by its roles, each role prefixed with `ROLE_`.
     */
    override fun createSecurityContext(annotation: WithMockTenantUser): SecurityContext {
        val grantedAuthorities: List<GrantedAuthority> =
            annotation.authorities.map { SimpleGrantedAuthority(it) } +
                annotation.roles.map { SimpleGrantedAuthority("ROLE_$it") }

        val principal: UserDetails = User(annotation.username, annotation.password, grantedAuthorities)
        val auth: Authentication = TenantCaptchaAuthenticationToken(
            principal,
            annotation.password,
            annotation.tenant,
            "",
            grantedAuthorities
        )

        return SecurityContextHolder.createEmptyContext().also { it.authentication = auth }
    }
}

使用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
/**
 * MockMvc test for the department endpoints, run as the mock tenant user `admin` with all
 * `sys:department:*` authorities listed in the annotation.
 */
@SpringBootTest
@AutoConfigureMockMvc
@TestMethodOrder(OrderAnnotation::class)
@WithMockTenantUser(
    username = "admin",
    authorities = [
        "sys:department:create",
        "sys:department:update",
        "sys:department:delete",
        "sys:department:list",
        "sys:department:page",
        "sys:department:tree",
        "sys:department:tree-root",
        "sys:department:id"
    ]
)
class SysDepartmentControllerTest {

    @Autowired
    lateinit var mockMvc: MockMvc

    @Autowired
    lateinit var objectMapper: ObjectMapper

    /** Creating a department answers 200 with `success == true`; ordered to run first. */
    @Order(Int.MIN_VALUE)
    @Test
    fun create() {
        val input = SysDepartmentCreateInput(
            name = "测试部门",
            description = "test",
            sort = 1,
            status = true,
            parentId = 1
        )

        val result = mockMvc.post("/api/sys/department/create") {
            content = objectMapper.writeValueAsString(input)
            contentType = MediaType.APPLICATION_JSON
        }

        result.andExpect {
            status { isOk() }
            content {
                jsonPath("$.success") {
                    exists()
                    value(true)
                }
            }
        }
    }
}

Spring-Security-Multi-Tenant-Test
https://zhengchalei.github.io/2024/08/16/Spring-Security-Multi-Tenant-Test/
作者
ZhengChaLei
发布于
2024年8月16日
许可协议