1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158
| import com.nimbusds.jose.JOSEException import com.nimbusds.jose.JWSAlgorithm import com.nimbusds.jose.JWSHeader import com.nimbusds.jose.crypto.MACSigner import com.nimbusds.jose.crypto.MACVerifier import com.nimbusds.jwt.JWTClaimsSet import com.nimbusds.jwt.SignedJWT import com.zhengchalei.service.sys.repository.SysDictRepository import org.slf4j.LoggerFactory import org.springframework.boot.CommandLineRunner import org.springframework.security.core.Authentication import org.springframework.security.core.authority.SimpleGrantedAuthority import org.springframework.security.core.userdetails.User import org.springframework.stereotype.Component import java.util.*
/**
 * Issues, verifies and decodes HS256-signed JWTs for [TenantCaptchaAuthenticationToken] sessions.
 *
 * The HMAC secret and the token lifetime are read once at startup (see [run]) from the `jwt`
 * dictionary in [SysDictRepository]. Until then, and whenever a configured value is missing or
 * invalid, built-in defaults are used and the fallback is logged.
 */
@Component
class JwtProvider(private val dictRepository: SysDictRepository) : CommandLineRunner {

    private val logger = LoggerFactory.getLogger(JwtProvider::class.java)

    // HMAC key for HS256. Nimbus' MACSigner/MACVerifier reject keys shorter than 256 bits (32 bytes),
    // so every value assigned here must be at least MIN_SECRET_BYTES long.
    private var secret: ByteArray = DEFAULT_SECRET.toByteArray(Charsets.UTF_8)

    // Token lifetime in milliseconds; added to System.currentTimeMillis() in createToken.
    private var expiration: Long = DEFAULT_EXPIRATION_MILLIS

    /**
     * Loads the `secret` and `expired` (in seconds) items of the `jwt` dictionary.
     *
     * A missing or too-short secret falls back to [DEFAULT_SECRET]. A missing, non-numeric or
     * non-positive expiry falls back to [DEFAULT_EXPIRATION_MILLIS]. Each fallback is logged.
     */
    override fun run(vararg args: String) {
        val dictItems = dictRepository.findByCode("jwt")?.dictItems ?: emptyList()
        val secretItem = dictItems.find { it.code == "secret" }
        val expiredItem = dictItems.find { it.code == "expired" }

        val configuredSecret = secretItem?.data?.toByteArray(Charsets.UTF_8)
        when {
            configuredSecret == null -> {
                logger.error("JWT 参数初始化失败, 缺少 'secret', 正在使用默认值: 危险")
                this.secret = DEFAULT_SECRET.toByteArray(Charsets.UTF_8)
            }
            configuredSecret.size < MIN_SECRET_BYTES -> {
                logger.error("JWT 参数错误, 'secret' 必须大于 32 个字符, 正在使用默认值: 危险")
                this.secret = DEFAULT_SECRET.toByteArray(Charsets.UTF_8)
            }
            else -> this.secret = configuredSecret
        }

        // toLongOrNull: a non-numeric value must not crash application startup.
        val expiredSeconds = expiredItem?.data?.trim()?.toLongOrNull()
        when {
            expiredItem == null -> {
                logger.warn("JWT 参数初始化失败, 缺少 'expired'")
                this.expiration = DEFAULT_EXPIRATION_MILLIS
            }
            // 0 would make every token expire immediately, so it is rejected too.
            expiredSeconds == null || expiredSeconds <= 0 -> {
                logger.warn("JWT 参数初始化错误, 'expired' 必须大于 0 , 正在使用默认值: 危险")
                this.expiration = DEFAULT_EXPIRATION_MILLIS
            }
            else -> this.expiration = expiredSeconds * 1000
        }

        logger.info("jwt 参数初始化完毕")
    }

    /**
     * Builds and signs (HS256) a token for [authentication].
     *
     * Authorities starting with `ROLE_` go into the comma-joined `roles` claim; all others go into
     * `permissions`. The token expires `expiration` milliseconds from now.
     */
    fun createToken(authentication: TenantCaptchaAuthenticationToken): String {
        val (roles, permissions) = authentication.authorities
            .mapNotNull { it.authority }
            .partition { it.startsWith("ROLE_") }

        val claimsSet = JWTClaimsSet.Builder()
            .subject(authentication.name)
            .claim("username", authentication.name)
            .claim("roles", roles.joinToString(","))
            .claim("permissions", permissions.joinToString(","))
            .claim("tenant", authentication.tenant)
            .expirationTime(Date(System.currentTimeMillis() + expiration))
            .build()

        val signedJWT = SignedJWT(JWSHeader(JWSAlgorithm.HS256), claimsSet)
        signedJWT.sign(MACSigner(secret))
        return signedJWT.serialize()
    }

    /**
     * Parses [token] WITHOUT verifying its signature or expiry.
     *
     * @throws java.text.ParseException if the token is not a well-formed signed JWT.
     */
    fun parseToken(token: String): SignedJWT {
        return SignedJWT.parse(token)
    }

    /**
     * Returns `true` only if [token] parses, carries a valid HMAC signature for the current
     * secret, and has an `exp` claim that is not in the past.
     *
     * Malformed tokens, verification errors and tokens without `exp` yield `false` instead of
     * throwing.
     */
    fun validateToken(token: String): Boolean {
        return try {
            val jwt = parseToken(token)
            if (!jwt.verify(MACVerifier(secret))) return false
            val expirationTime = jwt.jwtClaimsSet.expirationTime ?: return false
            !expirationTime.before(Date())
        } catch (e: java.text.ParseException) {
            logger.debug("Rejected malformed JWT: {}", e.message)
            false
        } catch (e: JOSEException) {
            logger.debug("JWT signature verification failed: {}", e.message)
            false
        }
    }

    /**
     * Rebuilds an [Authentication] from the claims of [token].
     *
     * The signature is NOT checked here; callers are expected to call [validateToken] first.
     * Missing or empty `roles`/`permissions` claims produce no authorities, rather than a blank
     * one that [SimpleGrantedAuthority] would reject.
     */
    fun getAuthentication(token: String): Authentication {
        val jwtClaimsSet = parseToken(token).jwtClaimsSet
            ?: throw RuntimeException("JWT ClaimsSet is null")
        val permissions = splitListClaim(jwtClaimsSet, "permissions")
        val roles = splitListClaim(jwtClaimsSet, "roles")
        val tenant = jwtClaimsSet.getStringClaim("tenant")

        val authorities = (permissions + roles).map { SimpleGrantedAuthority(it) }

        val principal = User(jwtClaimsSet.subject, "", authorities)

        return TenantCaptchaAuthenticationToken(
            username = principal,
            password = "",
            tenant = tenant,
            captcha = "",
            authorities = authorities,
        )
    }

    // Splits a comma-joined string claim, treating an absent claim as empty and dropping blank entries.
    private fun splitListClaim(jwtClaimsSet: JWTClaimsSet, key: String): List<String> =
        jwtClaimsSet.getStringClaim(key).orEmpty().split(",").filter { it.isNotBlank() }

    private companion object {
        // Last-resort fallback. It is 40 bytes, which satisfies HS256's 32-byte minimum. Because it
        // is publicly known, anyone could forge tokens while it is in use, so configure `jwt.secret`.
        const val DEFAULT_SECRET = "zhengchalei.github.io-default-jwt-secret"

        // Minimum HMAC key length for HS256 (256 bits).
        const val MIN_SECRET_BYTES = 32

        // One hour, in milliseconds, matching the unit of `expiration`.
        const val DEFAULT_EXPIRATION_MILLIS = 3_600_000L
    }
}
|