1 module jwtlited.openssl;
2 
3 public import jwtlited;
4 version (assert) import core.stdc.stdio;
5 
6 import deimos.openssl.ec;
7 import deimos.openssl.err;
8 import deimos.openssl.evp;
9 import deimos.openssl.hmac;
10 import deimos.openssl.pem;
11 import deimos.openssl.sha;
12 
13 import core.exception : onOutOfMemoryError;
14 
// Some OpenSSL symbols used in this module that are missing from the imported
// deimos.openssl bindings, so they are declared here by hand.
// NOTE(review): these look like OpenSSL 1.1.0+ APIs — confirm the minimum supported OpenSSL version.
extern(C) nothrow @nogc
{
    /// Frees a digest context created by `EVP_MD_CTX_new` (used when `PEMImpl` releases its contexts).
    void EVP_MD_CTX_free(EVP_MD_CTX* ctx);
    /// Allocates a digest context; `PEMImpl` treats a null result as out of memory.
    EVP_MD_CTX* EVP_MD_CTX_new();
    /// Sets R and S of `sig`; used to rebuild a DER encodable ECDSA signature from the raw token form.
    int ECDSA_SIG_set0(ECDSA_SIG* sig, BIGNUM* r, BIGNUM* s);
    /// Reads R and S of `sig`; used to convert OpenSSL's DER ECDSA signature to the raw token form.
    // NOTE(review): in D `const BIGNUM**` means `const(BIGNUM**)` (fully const) while the C function
    // writes through `pr`/`ps`; `const(BIGNUM)**` would mirror the C prototype more closely.
    void ECDSA_SIG_get0(const(ECDSA_SIG)* sig, const BIGNUM** pr, const BIGNUM** ps);
}
23 
/// HMAC with SHA-256 (`HS256`) token signer and validator.
alias HS256Handler = HMACImpl!(JWTAlgorithm.HS256);
/// HMAC with SHA-384 (`HS384`) token signer and validator.
alias HS384Handler = HMACImpl!(JWTAlgorithm.HS384);
/// HMAC with SHA-512 (`HS512`) token signer and validator.
alias HS512Handler = HMACImpl!(JWTAlgorithm.HS512);
27 
/**
 * Implementation of HS256, HS384 and HS512 signing algorithms.
 *
 * A single `HMAC_CTX` is initialized with the secret in `loadKey` and re-used for every
 * signature; after each computation it is re-initialized with the same key so it is ready
 * for the next one.
 */
private struct HMACImpl(JWTAlgorithm implAlg)
{
    static if (implAlg == JWTAlgorithm.HS256) enum signLen = SHA256_DIGEST_LENGTH;
    else static if (implAlg == JWTAlgorithm.HS384) enum signLen = SHA384_DIGEST_LENGTH;
    else static if (implAlg == JWTAlgorithm.HS512) enum signLen = SHA512_DIGEST_LENGTH;
    else static assert(0, "Unsupprted algorithm for HMAC implementation");

    private
    {
        const(char)[] key;      // slice of the secret passed to loadKey (not copied)
        HMAC_CTX ctx;           // HMAC context initialized with the secret
        ubyte[signLen] sigBuf;  // output of the last signature computation
    }

    @disable this(this);

    ~this() @trusted
    {
        HMAC_CTX_reset(&ctx);
    }

    /**
     * Sets the HMAC secret.
     *
     * Params:
     *   key = secret; only a slice of it is kept, so it must outlive this handler.
     * Returns: true on success, false for an empty key or when OpenSSL fails to initialize
     *          the HMAC context (no key is considered loaded in that case).
     */
    bool loadKey(K)(K key) if (isToken!K)
    {
        if (!key.length) return false;

        // Forget any previous key first, so a failed initialization can't leave a key that
        // looks usable on top of a context that is not initialized.
        this.key = null;

        immutable ok = () @trusted
        {
            static if (implAlg == JWTAlgorithm.HS256) alias evp = EVP_sha256;
            else static if (implAlg == JWTAlgorithm.HS384) alias evp = EVP_sha384;
            else static if (implAlg == JWTAlgorithm.HS512) alias evp = EVP_sha512;

            HMAC_CTX_reset(&ctx);
            return HMAC_Init_ex(&ctx, key.ptr, cast(int)key.length, evp(), null) != 0;
        }();
        if (!ok) return false;

        this.key = cast(const(char)[])key;
        return true;
    }

    /// Returns: true when `alg` is the algorithm implemented by this handler.
    bool isValidAlg(JWTAlgorithm alg) { return implAlg == alg; }

    /**
     * Checks that `sign` is the HMAC of `value`.
     *
     * The comparison runs in constant time with respect to the signature content, so the
     * position of the first mismatching byte is not leaked through timing.
     */
    bool isValid(V, S)(V value, S sign) if (isToken!V && isToken!S)
    {
        if (!genSignature(value)) return false;

        auto provided = cast(const(ubyte)[])sign;
        if (provided.length != signLen) return false; // length is public, early exit is fine

        ubyte diff;
        foreach (i, b; provided) diff |= cast(ubyte)(b ^ sigBuf[i]);
        return diff == 0;
    }

    /// Returns: the algorithm used for signing.
    JWTAlgorithm signAlg() { return implAlg; }

    /**
     * Computes the HMAC of `value` and writes it to `sink`.
     * Returns: number of signature bytes written, or -1 on failure.
     */
    int sign(S, V)(auto ref S sink, auto ref V value) if (isToken!V)
    {
        import std.range : put;
        if (!genSignature(value)) return -1;
        put(sink, sigBuf[0..signLen]);
        return signLen;
    }

    /// Computes the HMAC of `value` into `sigBuf`. Returns: false on failure or empty input.
    private bool genSignature(V)(V value) @trusted
    {
        assert(key.length, "Secret key not set");
        if (!key.length || !value.length) return false;

        // Re-initialize the context with the same key (null key/md means "reuse") for the next call.
        scope (exit) HMAC_Init_ex(&ctx, null, 0, null, null);

        if (!HMAC_Update(&ctx, cast(const(ubyte)*)value.ptr, value.length)) return false;

        uint slen;
        // Check the result before trusting the length written by HMAC_Final.
        if (!HMAC_Final(&ctx, sigBuf.ptr, &slen)) return false;
        assert(slen == signLen);
        return true;
    }
}
105 
/// RSA with SHA-256 (`RS256`) token signer and validator using PEM encoded keys.
alias RS256Handler = PEMImpl!(JWTAlgorithm.RS256);
/// RSA with SHA-384 (`RS384`) token signer and validator using PEM encoded keys.
alias RS384Handler = PEMImpl!(JWTAlgorithm.RS384);
/// RSA with SHA-512 (`RS512`) token signer and validator using PEM encoded keys.
alias RS512Handler = PEMImpl!(JWTAlgorithm.RS512);
/// ECDSA with SHA-256 (`ES256`) token signer and validator using PEM encoded keys.
alias ES256Handler = PEMImpl!(JWTAlgorithm.ES256);
/// ECDSA with SHA-384 (`ES384`) token signer and validator using PEM encoded keys.
alias ES384Handler = PEMImpl!(JWTAlgorithm.ES384);
/// ECDSA with SHA-512 (`ES512`) token signer and validator using PEM encoded keys.
alias ES512Handler = PEMImpl!(JWTAlgorithm.ES512);
112 
/**
 * Implementation of RS256, RS384, RS512 (RSA) and ES256, ES384, ES512 (ECDSA) signing
 * algorithms using PEM encoded keys.
 *
 * The public key (`loadKey`) is used for validation and the private key (`loadPKey`) for
 * signing. Each loaded key owns its own `EVP_MD_CTX`, which is re-initialized on every use.
 *
 * OpenSSL produces and consumes ECDSA signatures as DER encoded `ECDSA_SIG`, while tokens
 * carry the fixed size raw `R || S` form, so ES* signatures are converted in both directions.
 */
private struct PEMImpl(JWTAlgorithm implAlg)
{
    private
    {
        EVP_PKEY* pubKey;       // public key used by isValid
        EVP_PKEY* privKey;      // private key used by sign
        EVP_MD_CTX* mdctxPriv;  // digest context for signing, allocated together with privKey
        EVP_MD_CTX* mdctxPub;   // digest context for validation, allocated together with pubKey

        import std.algorithm : among;
        static if (implAlg.among(JWTAlgorithm.ES256, JWTAlgorithm.ES384, JWTAlgorithm.ES512))
        {
            enum type = EVP_PKEY_EC;
            int slen; // length of the raw R || S signature expected for the loaded public key
        }
        else static if (implAlg.among(JWTAlgorithm.RS256, JWTAlgorithm.RS384, JWTAlgorithm.RS512))
            enum type = EVP_PKEY_RSA;
        else static assert(0, "Unsupprted algorithm for PEM implementation");

        static if (implAlg == JWTAlgorithm.ES256) alias evp = EVP_sha256;
        else static if (implAlg == JWTAlgorithm.ES384) alias evp = EVP_sha384;
        else static if (implAlg == JWTAlgorithm.ES512) alias evp = EVP_sha512;
        else static if (implAlg == JWTAlgorithm.RS256) alias evp = EVP_sha256;
        else static if (implAlg == JWTAlgorithm.RS384) alias evp = EVP_sha384;
        else static if (implAlg == JWTAlgorithm.RS512) alias evp = EVP_sha512;
    }

    @disable this(this);

    ~this() @trusted
    {
        freePubKey();
        freePrivKey();
    }

    /**
     * Loads the PEM encoded public key used for validation.
     *
     * Any previously loaded public key is released first, so after a failed call no public
     * key is set and validation fails.
     *
     * Returns: true on success, false when the key is empty, can't be parsed or is not of the
     *          key type required by the algorithm.
     */
    bool loadKey(K)(K key) @trusted if (isToken!K)
    {
        freePubKey();
        if (!key.length) return false;

        EVP_PKEY* pkey = readPEMKey!false(key);
        if (!pkey) return false;

        static if (type == EVP_PKEY_EC)
        {
            // Raw signature is R and S, each padded to the byte size of the curve.
            immutable bn_len = (ecDegree(pkey) + 7) / 8;
            slen = bn_len * 2;
        }

        auto ctx = EVP_MD_CTX_new();
        if (!ctx) onOutOfMemoryError();

        // Only publish the key once everything it needs is ready.
        pubKey = pkey;
        mdctxPub = ctx;
        return true;
    }

    /**
     * Loads the PEM encoded private key used for signing.
     *
     * Any previously loaded private key is released first, so after a failed call no private
     * key is set and signing fails.
     *
     * Returns: true on success, false when the key is empty, can't be parsed or is not of the
     *          key type required by the algorithm.
     */
    bool loadPKey(K)(K key) @trusted if (isToken!K)
    {
        freePrivKey();
        if (!key.length) return false;

        EVP_PKEY* pkey = readPEMKey!true(key);
        if (!pkey) return false;

        auto ctx = EVP_MD_CTX_new();
        if (!ctx) onOutOfMemoryError();

        privKey = pkey;
        mdctxPriv = ctx;
        return true;
    }

    /// Returns: true when `alg` is the algorithm implemented by this handler.
    bool isValidAlg(JWTAlgorithm alg) { return implAlg == alg; }

    /**
     * Verifies that `sign` is a valid signature of `value` for the loaded public key.
     *
     * For ECDSA, `sign` must be the raw `R || S` form with the length derived from the key.
     */
    bool isValid(V, S)(V value, S sign) @trusted if (isToken!V && isToken!S)
    {
        version (unittest) {} // no assert behavior is tested in unittest
        else assert(pubKey, "Public key not set");
        if (!value.length || !sign.length || !pubKey) return false;

        static if (type == EVP_PKEY_EC)
        {
            // Convert the raw R || S signature back to the DER encoding OpenSSL expects.
            if (sign.length != slen) return false;

            ubyte[256] sbuf;
            immutable bn_len = slen / 2;
            auto ec_sig_r = BN_bin2bn(sign.ptr, bn_len, null);
            auto ec_sig_s = BN_bin2bn(sign.ptr + bn_len, bn_len, null);
            if (!ec_sig_r || !ec_sig_s)
            {
                // BN_free ignores null, so whichever one was allocated is released.
                BN_free(ec_sig_r);
                BN_free(ec_sig_s);
                return false;
            }

            auto ec_sig = ECDSA_SIG_new();
            if (!ec_sig) onOutOfMemoryError;
            scope (exit) ECDSA_SIG_free(ec_sig);
            if (ECDSA_SIG_set0(ec_sig, ec_sig_r, ec_sig_s) != 1)
            {
                // Ownership of R and S is only transferred to ec_sig on success.
                BN_free(ec_sig_r);
                BN_free(ec_sig_s);
                return false;
            }

            // First call queries the encoded length, second one writes into sbuf.
            auto siglen = i2d_ECDSA_SIG(ec_sig, null);
            if (siglen <= 0 || siglen > sbuf.length) return false;
            auto p = &sbuf[0];
            siglen = i2d_ECDSA_SIG(ec_sig, &p);
            if (siglen <= 0) return false;
            auto psig = &sbuf[0];
        }
        else
        {
            auto psig = sign.ptr;
            immutable siglen = cast(int)sign.length;
        }

        // Initialize the DigestVerify operation using evp algorithm
        if (EVP_DigestVerifyInit(mdctxPub, null, evp, null, pubKey) != 1)
            return false;

        if (EVP_DigestVerifyUpdate(mdctxPub, value.ptr, value.length) != 1)
            return false;

        auto ret = EVP_DigestVerifyFinal(mdctxPub, psig, siglen);
        if (ret == -1)
        {
            version (assert) ERR_print_errors_fp(stderr);
            return false;
        }
        return ret == 1;
    }

    /// Returns: the algorithm used for signing.
    JWTAlgorithm signAlg() { return implAlg; }

    /**
     * Signs `value` with the loaded private key and writes the signature to `sink`.
     *
     * For ECDSA the signature is written in the raw `R || S` form.
     *
     * Returns: number of signature bytes written, or -1 on failure (no private key, OpenSSL
     *          error, or a signature larger than the internal 512 byte buffers).
     */
    int sign(S, V)(auto ref S sink, auto ref V value) @trusted
    {
        import std.range : put;

        version (unittest) {} // no assert behavior is tested in unittest
        else assert(privKey, "Private key not set");
        if (!value.length || !privKey) return -1;

        // Initialize the DigestSign operation using alg
        if (EVP_DigestSignInit(mdctxPriv, null, evp, null, privKey) != 1)
            return -1;

        // Call update with the message
        if (EVP_DigestSignUpdate(mdctxPriv, value.ptr, value.length) != 1)
            return -1;

        // First, call EVP_DigestSignFinal with a null sig parameter to get the signature length.
        ubyte[512] sig;
        size_t siglen;
        if (EVP_DigestSignFinal(mdctxPriv, null, &siglen) != 1)
            return -1;

        // Runtime check (not an assert): keys too large for the buffer, e.g. RSA > 4096 bits,
        // must fail instead of overflowing the stack in release builds.
        if (siglen > sig.length) return -1;

        // Get the signature with real length
        if (EVP_DigestSignFinal(mdctxPriv, &sig[0], &siglen) != 1)
            return -1;

        static if (type != EVP_PKEY_EC)
        {
            put(sink, sig[0 .. siglen]); // just return the signature as is
            return cast(int)siglen;
        }
        else
        {
            // For EC we need to convert the DER encoded signature to the raw R || S format.
            immutable bn_len = (ecDegree(privKey) + 7) / 8;

            ubyte* ps = sig.ptr;
            auto ec_sig = d2i_ECDSA_SIG(null, cast(const(ubyte)**)&ps, siglen);
            if (!ec_sig) return -1; // malformed DER, not an allocation failure
            scope (exit) ECDSA_SIG_free(ec_sig);

            // Pointers to const BIGNUM (not const pointers): ECDSA_SIG_get0 writes them.
            const(BIGNUM)* ec_sig_r;
            const(BIGNUM)* ec_sig_s;
            ECDSA_SIG_get0(ec_sig, &ec_sig_r, &ec_sig_s);
            immutable r_len = BN_num_bytes(ec_sig_r);
            immutable s_len = BN_num_bytes(ec_sig_s);
            if (r_len > bn_len || s_len > bn_len)
                return -1;

            // Zero initialized, so shorter R/S values end up left-padded with zero bytes.
            ubyte[512] buf;
            immutable rawLen = 2 * bn_len;
            if (rawLen > buf.length) return -1;

            BN_bn2bin(ec_sig_r, buf.ptr + bn_len - r_len);
            BN_bn2bin(ec_sig_s, buf.ptr + rawLen - s_len);

            put(sink, buf[0 .. rawLen]);
            return cast(int)rawLen;
        }
    }

    /**
     * Parses a PEM encoded public (`isPrivate == false`) or private key and checks that it
     * has the key type required by the algorithm.
     *
     * Returns: the key (owned by the caller) or null on failure.
     */
    private static EVP_PKEY* readPEMKey(bool isPrivate, K)(K key) @trusted
    {
        BIO* bpo = BIO_new_mem_buf(cast(char*)key.ptr, cast(int)key.length);
        if (!bpo) onOutOfMemoryError();
        scope (exit) BIO_free(bpo);

        // TODO: Uses OpenSSL's default passphrase callbacks if needed.
        EVP_PKEY* pkey;
        static if (isPrivate) pkey = PEM_read_bio_PrivateKey(bpo, null, null, null);
        else pkey = PEM_read_bio_PUBKEY(bpo, null, null, null);
        if (!pkey)
        {
            version (assert) ERR_print_errors_fp(stderr);
            return null;
        }

        // Reject keys of another type (e.g. an RSA key for ES256) to prevent algorithm confusion.
        if (EVP_PKEY_id(pkey) != type)
        {
            EVP_PKEY_free(pkey);
            return null;
        }
        return pkey;
    }

    static if (type == EVP_PKEY_EC)
    {
        /// Returns: the degree (bit size) of the curve of the EC key `pkey`.
        private static auto ecDegree(EVP_PKEY* pkey) @trusted
        {
            auto ec_key = EVP_PKEY_get1_EC_KEY(pkey);
            if (!ec_key) onOutOfMemoryError();
            scope (exit) EC_KEY_free(ec_key);
            return EC_GROUP_get_degree(EC_KEY_get0_group(ec_key));
        }
    }

    /// Releases the public key together with its digest context.
    private void freePubKey() @trusted
    {
        if (pubKey) EVP_PKEY_free(pubKey);
        if (mdctxPub) EVP_MD_CTX_free(mdctxPub);
        pubKey = null;
        mdctxPub = null;
    }

    /// Releases the private key together with its digest context.
    private void freePrivKey() @trusted
    {
        if (privKey) EVP_PKEY_free(privKey);
        if (mdctxPriv) EVP_MD_CTX_free(mdctxPriv);
        privKey = null;
        mdctxPriv = null;
    }
}
336 
///
@safe unittest
{
    import jwtlited.openssl;
    import std.stdio;

    // Matching EC key pair in PEM format: the private key signs, the public key verifies.
    enum EC_PUBKEY = `-----BEGIN PUBLIC KEY-----
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEMlFGAIxe+/zLanxz4bOxTI6daFBk
NGyQ+P4bc/RmNEq1NpsogiMB5eXC7jUcD/XqxP9HCIhdRBcQHx7aOo3ayQ==
-----END PUBLIC KEY-----`;

    enum EC_PRIVKEY = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEILvM6E7mLOdndALDyFc3sOgUTb6iVjgwRBtBwYZngSuwoAoGCCqGSM49
AwEHoUQDQgAEMlFGAIxe+/zLanxz4bOxTI6daFBkNGyQ+P4bc/RmNEq1NpsogiMB
5eXC7jUcD/XqxP9HCIhdRBcQHx7aOo3ayQ==
-----END EC PRIVATE KEY-----`;

    enum payload = `{"foo":42}`;
    ES256Handler es256;

    // Sign the payload with the private key.
    immutable signerReady = es256.loadPKey(EC_PRIVKEY);
    assert(signerReady);
    char[512] token;
    immutable tokenLen = es256.encode(token[], payload);
    assert(tokenLen > 0);
    writeln("ES256: ", token[0 .. tokenLen]);

    // Load the public key into the same handler and verify the token.
    immutable verifierReady = es256.loadKey(EC_PUBKEY);
    assert(verifierReady);
    assert(es256.validate(token[0 .. tokenLen]));

    // Decoding must give back the original payload.
    char[32] decoded;
    assert(es256.decode(token[0 .. tokenLen], decoded[]));
    assert(decoded[0 .. payload.length] == payload);
}
370 
@safe unittest
{
    // Every handler family must satisfy both the validator and the signer interface.
    static assert(isValidator!HS256Handler && isSigner!HS256Handler);
    static assert(isValidator!ES256Handler && isSigner!ES256Handler);
    static assert(isValidator!RS256Handler && isSigner!RS256Handler);
}
380 
@("ECDSA - Test fail on uninitialized keys")
@safe unittest
{
    // A handler without any loaded key can neither sign nor validate.
    ES256Handler keyless;

    char[512] buf;
    immutable written = keyless.encode(buf[], `{"foo":42}`);
    assert(written < 0);

    enum someToken = "eyJhbGciOiJFUzI1NiJ9.eyJmb28iOjQyfQ.R_MeWV0nLqRcNk9OrczuhykhKJn2wBZIgmwF87TivMlLGk2KB4Ekec9aXz0dOxBfYQflP6PwdSNjgLdYMECwRA";
    assert(!keyless.validate(someToken));
}
390 
391 version (unittest) import jwtlited.tests;
392 
// Runs the shared test cases (from jwtlited.tests) against every handler of this module and
// checks that the whole run does not allocate from the GC.
@("OpenSSL tests")
@safe unittest
{
    // Creates handler H, loads the key(s) of test case `tc`, checks that loading succeeds
    // exactly when the test case marks the key as valid, then runs the shared evaluation.
    static void eval(H)(ref immutable TestCase tc)
    {
        H h;
        // HMAC handlers have a single secret key.
        static if (is(H == HS256Handler) || is(H == HS384Handler) || is(H == HS512Handler))
            assert(h.loadKey(tc.key) == !!(tc.valid & Valid.key));
        else
        {
            // PEM handlers: the public key is loaded for decode tests,
            // the private key for encode tests.
            if (tc.test & Test.decode)
                assert(h.loadKey(tc.key) == !!(tc.valid & Valid.key));
            if (tc.test & Test.encode)
                assert(h.loadPKey(tc.pkey) == !!(tc.valid & Valid.key));
        }

        // NOTE(review): evalTest comes from jwtlited.tests, which is not visible here.
        evalTest(h, tc);
    }

    // Bytes allocated by the GC in the current thread; frontends before 2.094 expose it
    // through GC.stats() instead of GC.allocatedInCurrentThread().
    static auto allocatedInCurrentThread() @trusted
    {
        import core.memory : GC;
        static if (__VERSION__ >= 2094) return GC.allocatedInCurrentThread();
        else return GC.stats().allocatedInCurrentThread;
    }

    import std.algorithm : canFind, filter;

    immutable pre = allocatedInCurrentThread();

    with (JWTAlgorithm)
    {
        // Algorithms implemented by this module; test cases for other algorithms are skipped.
        static immutable testAlgs = [
            HS256, HS384, HS512,
            RS256, RS384, RS512,
            ES256, ES384, ES512
        ];

        foreach (tc; testCases.filter!(a => testAlgs.canFind(a.alg)))
        {
            // Dispatch each test case to the handler implementing its algorithm.
            final switch (tc.alg)
            {
                case none: assert(0); // filtered out above, as `none` is not in testAlgs
                case HS256: eval!HS256Handler(tc); break;
                case HS384: eval!HS384Handler(tc); break;
                case HS512: eval!HS512Handler(tc); break;

                case RS256: eval!RS256Handler(tc); break;
                case RS384: eval!RS384Handler(tc); break;
                case RS512: eval!RS512Handler(tc); break;

                case ES256: eval!ES256Handler(tc); break;
                case ES384: eval!ES384Handler(tc); break;
                case ES512: eval!ES512Handler(tc); break;
            }
        }
    }

    assert(allocatedInCurrentThread() - pre == 0); // check for no GC allocations
}