module jwtlited.gnutls;

public import jwtlited;
version (assert) import core.stdc.stdio;
import core.stdc.string : memcpy;
import bindbc.gnutls;

version (DYNAMIC_GNUTLS)
{
    /*
     * With the dynamic binding the GnuTLS shared library is loaded once at
     * program start. Anything below the 3.5.0 support level (which includes
     * loader failures) is reported on stderr and aborts the program.
     */
    shared static this()
    {
        import core.stdc.stdio;
        import loader = bindbc.loader.sharedlib;
        auto res = loadGnuTLS();
        if (res < GnuTLSSupport.gnutls_3_5_0)
        {
            fprintf(stderr, "Error loading GnuTLS: %d\n", res);
            foreach(info; loader.errors)
            {
                fprintf(stderr, "\t%s: %s\n", info.error, info.message);
            }
            assert(0, "Error loading GnuTLS");
        }
    }
}

alias HS256Handler = HMACImpl!(JWTAlgorithm.HS256);
alias HS384Handler = HMACImpl!(JWTAlgorithm.HS384);
alias HS512Handler = HMACImpl!(JWTAlgorithm.HS512);

/**
 * Implementation of HS256, HS384 and HS512 signing algorithms.
 *
 * The handler keeps a slice of the secret passed to `loadKey` (it is not
 * copied) together with a GnuTLS HMAC handle that is reused for every
 * signature and released by the destructor.
 */
private struct HMACImpl(JWTAlgorithm implAlg)
{
    // Length in bytes of the MAC produced by the selected hash.
    static if (implAlg == JWTAlgorithm.HS256) enum signLen = 32;
    else static if (implAlg == JWTAlgorithm.HS384) enum signLen = 48;
    else static if (implAlg == JWTAlgorithm.HS512) enum signLen = 64;
    else static assert(0, "Unsupprted algorithm for HMAC implementation");

    private
    {
        const(char)[] key;      // secret as handed to loadKey (slice, not a copy)
        gnutls_hmac_hd_t ctx;   // GnuTLS HMAC handle, null while no key is loaded
        ubyte[signLen] sigBuf;  // scratch buffer receiving the computed MAC
    }

    // The handle is owned by this instance, so copies are not allowed.
    @disable this(this);

    ~this() @trusted
    {
        if (ctx) gnutls_hmac_deinit(ctx, null);
    }

    /**
     * Loads the shared secret used for signing and validation.
     *
     * A previously loaded key (and its HMAC handle) is released first, so the
     * method can be called repeatedly on the same instance.
     *
     * Params:
     *   key = secret key; the caller must keep the memory alive because only
     *         a slice of it is stored
     *
     * Returns: `true` when the HMAC context was initialised, `false` for an
     *          empty key or when GnuTLS reports an error.
     */
    bool loadKey(K)(K key) if (isToken!K)
    {
        if (!key.length) return false;
        this.key = cast(const(char)[])key;

        immutable ret = () @trusted
        {
            static if (implAlg == JWTAlgorithm.HS256) alias alg = gnutls_mac_algorithm_t.GNUTLS_MAC_SHA256;
            else static if (implAlg == JWTAlgorithm.HS384) alias alg = gnutls_mac_algorithm_t.GNUTLS_MAC_SHA384;
            else static if (implAlg == JWTAlgorithm.HS512) alias alg = gnutls_mac_algorithm_t.GNUTLS_MAC_SHA512;

            assert(signLen == gnutls_hmac_get_len(alg));

            // Release the handle of an earlier loadKey call instead of leaking it.
            if (ctx)
            {
                gnutls_hmac_deinit(ctx, null);
                ctx = null;
            }

            return gnutls_hmac_init(&ctx, alg, this.key.ptr, this.key.length);
        }();

        if (ret != 0)
        {
            // Do not keep a handle or key around after a failed init: the
            // handler must look "unloaded" so later calls fail cleanly.
            ctx = null;
            this.key = null;
            return false;
        }
        return true;
    }

    /// Tells whether `alg` is the algorithm this handler implements.
    bool isValidAlg(JWTAlgorithm alg) { return implAlg == alg; }

    /**
     * Checks that `sign` is the MAC of `value` under the loaded key.
     *
     * Params:
     *   value = signed data
     *   sign  = raw (already decoded) signature bytes
     *
     * Returns: `true` only when the MAC could be computed and matches.
     */
    bool isValid(V, S)(V value, S sign) if (isToken!V && isToken!S)
    {
        if (!genSignature(value)) return false;

        const provided = cast(const(ubyte)[])sign;
        if (provided.length != signLen) return false;

        // Compare without an early exit so the time taken does not reveal
        // how many leading bytes of a forged signature were correct.
        ubyte diff;
        foreach (i, b; sigBuf) diff |= b ^ provided[i];
        return diff == 0;
    }

    /// Algorithm this handler signs with.
    JWTAlgorithm signAlg() { return implAlg; }

    /**
     * Computes the MAC of `value` and writes it to `sink`.
     *
     * Returns: number of bytes written (`signLen`), or -1 on failure.
     */
    int sign(S, V)(auto ref S sink, auto ref V value) if (isToken!V)
    {
        import std.range : put;
        if (!genSignature(value)) return -1;
        put(sink, sigBuf[0..signLen]);
        return signLen;
    }

    // Computes the MAC of `value` into sigBuf; false when no usable key is
    // loaded, the value is empty or GnuTLS reports an error.
    private bool genSignature(V)(V value) @trusted
    {
        assert(key.length, "Secret key not set");
        if (!key.length || !ctx || !value.length) return false;

        auto ret = gnutls_hmac(ctx, value.ptr, value.length);
        if (ret < 0) return false;

        gnutls_hmac_output(ctx, sigBuf.ptr);
        return true;
    }
}

alias RS256Handler = PEMImpl!(JWTAlgorithm.RS256);
alias RS384Handler = PEMImpl!(JWTAlgorithm.RS384);
alias RS512Handler = PEMImpl!(JWTAlgorithm.RS512);
alias ES256Handler = PEMImpl!(JWTAlgorithm.ES256);
alias ES384Handler = PEMImpl!(JWTAlgorithm.ES384);
alias ES512Handler = PEMImpl!(JWTAlgorithm.ES512);

/**
 * Implementation of RS256, RS384, RS512, ES256, ES384 and ES512 signing
 * algorithms using PEM encoded keys.
 *
 * `loadKey` imports the public key used by `isValid`, `loadPKey` imports the
 * private key used by `sign`. All GnuTLS key handles are owned by the
 * instance and released by the destructor.
 */
private struct PEMImpl(JWTAlgorithm implAlg)
{
    private
    {
        gnutls_x509_privkey_t x509key;  // parsed PEM private key
        gnutls_privkey_t privKey;       // abstract private key imported from x509key
        gnutls_pubkey_t pubKey;         // public key used for verification

        import std.algorithm : among;
        static if (implAlg.among(JWTAlgorithm.ES256, JWTAlgorithm.ES384, JWTAlgorithm.ES512))
            enum pkAlg = gnutls_pk_algorithm_t.GNUTLS_PK_ECDSA;
        else static if (implAlg.among(JWTAlgorithm.RS256, JWTAlgorithm.RS384, JWTAlgorithm.RS512))
            enum pkAlg = gnutls_pk_algorithm_t.GNUTLS_PK_RSA;
        else static assert(0, "Unsupprted algorithm for PEM implementation");

        static if (implAlg == JWTAlgorithm.ES256) alias alg = gnutls_sign_algorithm_t.GNUTLS_SIGN_ECDSA_SHA256;
        else static if (implAlg == JWTAlgorithm.ES384) alias alg = gnutls_sign_algorithm_t.GNUTLS_SIGN_ECDSA_SHA384;
        else static if (implAlg == JWTAlgorithm.ES512) alias alg = gnutls_sign_algorithm_t.GNUTLS_SIGN_ECDSA_SHA512;
        else static if (implAlg == JWTAlgorithm.RS256) alias alg = gnutls_sign_algorithm_t.GNUTLS_SIGN_RSA_SHA256;
        else static if (implAlg == JWTAlgorithm.RS384) alias alg = gnutls_sign_algorithm_t.GNUTLS_SIGN_RSA_SHA384;
        else static if (implAlg == JWTAlgorithm.RS512) alias alg = gnutls_sign_algorithm_t.GNUTLS_SIGN_RSA_SHA512;

        // Width in bytes of one ECDSA signature coordinate; a JWT signature
        // is the fixed-width concatenation R || S, i.e. 2 * coordLen bytes.
        static if (pkAlg == gnutls_pk_algorithm_t.GNUTLS_PK_ECDSA)
        {
            static if (implAlg == JWTAlgorithm.ES256) enum coordLen = 32;
            else static if (implAlg == JWTAlgorithm.ES384) enum coordLen = 48;
            else enum coordLen = 66;
        }

        // Releases the public key handle, if any.
        void releaseKey() @trusted
        {
            if (pubKey) { gnutls_pubkey_deinit(pubKey); pubKey = null; }
        }

        // Releases both private key handles, if any. The abstract key goes
        // first because it was imported from the x509 one.
        void releasePKey() @trusted
        {
            if (privKey) { gnutls_privkey_deinit(privKey); privKey = null; }
            if (x509key) { gnutls_x509_privkey_deinit(x509key); x509key = null; }
        }
    }

    // Key handles are owned by this instance, so copies are not allowed.
    @disable this(this);

    ~this() @trusted
    {
        releasePKey();
        releaseKey();
    }

    /**
     * Loads the PEM encoded public key used by `isValid`.
     *
     * A previously loaded public key is released first.
     *
     * Returns: `true` when the key was imported, `false` for empty input or
     *          when GnuTLS rejects it.
     */
    bool loadKey(K)(K key) @trusted if (isToken!K)
    {
        if (!key.length) return false;

        releaseKey(); // don't leak the handle of an earlier loadKey call
        if (gnutls_pubkey_init(&pubKey))
        {
            pubKey = null;
            return false;
        }

        gnutls_datum_t cert_dat = gnutls_datum_t(cast(ubyte*)key.ptr, cast(uint)key.length);
        if (gnutls_pubkey_import(pubKey, &cert_dat, gnutls_x509_crt_fmt_t.GNUTLS_X509_FMT_PEM))
        {
            releaseKey();
            return false;
        }

        return true;
    }

    /**
     * Loads the PEM encoded private key used by `sign`.
     *
     * A previously loaded private key is released first. The key is rejected
     * when its public key algorithm does not match the handler (RSA vs ECDSA).
     *
     * Returns: `true` when the key was imported, `false` otherwise.
     */
    bool loadPKey(K)(K key) @trusted if (isToken!K)
    {
        if (!key.length) return false;

        releasePKey(); // don't leak the handles of an earlier loadPKey call
        if (gnutls_x509_privkey_init(&x509key))
        {
            x509key = null;
            return false;
        }

        gnutls_datum_t keyData = gnutls_datum_t(cast(ubyte*)key.ptr, cast(uint)key.length);

        // Every step must succeed (GnuTLS returns 0); evaluation stops at the
        // first failure thanks to short-circuiting.
        immutable ok =
            gnutls_x509_privkey_import(x509key, &keyData, gnutls_x509_crt_fmt_t.GNUTLS_X509_FMT_PEM) == 0
            && gnutls_privkey_init(&privKey) == 0
            && gnutls_privkey_import_x509(privKey, x509key, 0) == 0
            && gnutls_privkey_get_pk_algorithm(privKey, null) == pkAlg;

        if (!ok)
        {
            releasePKey();
            return false;
        }
        return true;
    }

    /// Tells whether `alg` is the algorithm this handler implements.
    bool isValidAlg(JWTAlgorithm alg) { return implAlg == alg; }

    /**
     * Verifies `sign` over `value` with the loaded public key.
     *
     * Params:
     *   value = signed data
     *   sign  = raw (already decoded) signature; for ECDSA it must be the
     *           fixed-width R || S form of exactly `2 * coordLen` bytes
     *
     * Returns: `true` only when GnuTLS confirms the signature.
     */
    bool isValid(V, S)(V value, S sign) @trusted if (isToken!V && isToken!S)
    {
        version (unittest) {} // no assert behavior is tested in unittest
        else assert(pubKey, "Public key not set");
        if (!value.length || !sign.length || !pubKey) return false;

        gnutls_datum_t data = gnutls_datum_t(cast(ubyte*)value.ptr, cast(uint)value.length);
        static if (pkAlg == gnutls_pk_algorithm_t.GNUTLS_PK_RSA)
        {
            gnutls_datum_t sig_dat = gnutls_datum_t(sign.ptr, cast(uint)sign.length);
            return gnutls_pubkey_verify_data2(pubKey, alg, 0, &data, &sig_dat) == 0;
        }
        else
        {
            // Rebuild the encoded signature GnuTLS expects from the r and s
            // halves of the raw signature.
            if (sign.length != 2 * coordLen) return false;

            gnutls_datum_t r, s;
            r.data = sign.ptr;
            r.size = coordLen;
            s.data = sign.ptr + coordLen;
            s.size = coordLen;

            gnutls_datum_t sig_dat;
            scope (exit)
            {
                if (sig_dat.data) gnutls_free(sig_dat.data);
            }

            if (gnutls_encode_rs_value(&sig_dat, &r, &s)) return false;
            return gnutls_pubkey_verify_data2(pubKey, alg, 0, &data, &sig_dat) == 0;
        }
    }

    /// Algorithm this handler signs with.
    JWTAlgorithm signAlg() { return implAlg; }

    /**
     * Signs `value` with the loaded private key and writes the raw signature
     * to `sink`.
     *
     * For RSA the signature is written as produced by GnuTLS. For ECDSA the
     * encoded signature is split into r and s and written as fixed-width
     * R || S.
     *
     * Returns: number of bytes written, or -1 on failure.
     */
    int sign(S, V)(auto ref S sink, auto ref V value) @trusted
    {
        import std.range : put;

        version (unittest) {} // no assert behavior is tested in unittest
        else assert(privKey, "Private key not set");
        if (!value.length || !privKey) return -1;

        gnutls_datum_t body_dat = gnutls_datum_t(cast(ubyte*)value.ptr, cast(uint)value.length);
        gnutls_datum_t sig_dat;

        if (gnutls_privkey_sign_data2(privKey, alg, 0, &body_dat, &sig_dat)) return -1;
        scope (exit) gnutls_free(sig_dat.data);

        static if (pkAlg == gnutls_pk_algorithm_t.GNUTLS_PK_RSA)
        {
            put(sink, sig_dat.data[0..sig_dat.size]);
            return sig_dat.size;
        }
        else
        {
            gnutls_datum_t r, s;
            if (gnutls_decode_rs_value(&sig_dat, &r, &s)) return -1;
            scope (exit)
            {
                gnutls_free(r.data);
                gnutls_free(s.data);
            }

            // Copies one coordinate right-aligned into its coordLen wide slot:
            // surplus leading bytes are dropped, missing ones stay zero.
            static void putCoord(ubyte* slot, gnutls_datum_t v)
            {
                immutable uint skip = v.size > coordLen ? v.size - coordLen : 0;
                immutable uint len = v.size - skip;
                memcpy(slot + (coordLen - len), v.data + skip, len);
            }

            ubyte[2 * coordLen] buf; // zero initialised, provides the left padding
            putCoord(buf.ptr, r);
            putCoord(buf.ptr + coordLen, s);

            put(sink, buf[]);
            return 2 * coordLen;
        }
    }
}

version (unittest) import jwtlited.tests;

@("GnuTLS tests")
@safe unittest
{
    // Loads the keys of a test case into a fresh handler of type H, checks
    // that loading succeeds exactly when the case marks the key as valid and
    // then runs the shared test routine.
    static void eval(H)(ref immutable TestCase tc)
    {
        H h;
        static if (is(H == HS256Handler) || is(H == HS384Handler) || is(H == HS512Handler))
            assert(h.loadKey(tc.key) == !!(tc.valid & Valid.key));
        else
        {
            if (tc.test & Test.decode)
                assert(h.loadKey(tc.key) == !!(tc.valid & Valid.key));
            if (tc.test & Test.encode)
                assert(h.loadPKey(tc.pkey) == !!(tc.valid & Valid.key));
        }

        evalTest(h, tc);
    }

    // Bytes allocated by the GC in the current thread (API differs by compiler version).
    static auto allocatedInCurrentThread() @trusted
    {
        import core.memory : GC;
        static if (__VERSION__ >= 2094) return GC.allocatedInCurrentThread();
        else return GC.stats().allocatedInCurrentThread;
    }

    import std.algorithm : canFind, filter;

    immutable pre = allocatedInCurrentThread();

    with (JWTAlgorithm)
    {
        static immutable testAlgs = [
            HS256, HS384, HS512,
            RS256, RS384, RS512,
            ES256, ES384, ES512
        ];

        foreach (tc; testCases.filter!(a => testAlgs.canFind(a.alg)))
        {
            final switch (tc.alg)
            {
                case none: assert(0);
                case HS256: eval!HS256Handler(tc); break;
                case HS384: eval!HS384Handler(tc); break;
                case HS512: eval!HS512Handler(tc); break;

                case RS256: eval!RS256Handler(tc); break;
                case RS384: eval!RS384Handler(tc); break;
                case RS512: eval!RS512Handler(tc); break;

                case ES256: eval!ES256Handler(tc); break;
                case ES384: eval!ES384Handler(tc); break;
                case ES512: eval!ES512Handler(tc); break;
            }
        }
    }

    assert(allocatedInCurrentThread() - pre == 0); // check for no GC allocations
}