1 module jwtlited.gnutls;
2 
3 public import jwtlited;
4 version (assert) import core.stdc.stdio;
5 import core.stdc.string : memcpy;
6 import bindbc.gnutls;
7 
version (DYNAMIC_GNUTLS)
{
    /*
     * Module constructor: binds the GnuTLS shared library when the program
     * starts. If the loader reports anything below GnuTLS 3.5.0 (including a
     * failed load), the loader's error list is printed to stderr and the
     * program is stopped with assert(0).
     */
    shared static this()
    {
        import core.stdc.stdio;
        import loader = bindbc.loader.sharedlib;

        auto status = loadGnuTLS();
        if (status >= GnuTLSSupport.gnutls_3_5_0)
            return; // a recent enough library was bound, nothing else to do

        fprintf(stderr, "Error loading GnuTLS: %d\n", status);
        foreach (e; loader.errors)
            fprintf(stderr, "\t%s: %s\n", e.error, e.message);

        assert(0, "Error loading GnuTLS");
    }
}
26 
alias HS256Handler = HMACImpl!(JWTAlgorithm.HS256);
alias HS384Handler = HMACImpl!(JWTAlgorithm.HS384);
alias HS512Handler = HMACImpl!(JWTAlgorithm.HS512);

/**
 * Implementation of HS256, HS384 and HS512 signing algorithms.
 *
 * A single GnuTLS HMAC context is created by `loadKey` and reused for every
 * signature computed afterwards. The struct owns that context and is
 * therefore non-copyable.
 */
private struct HMACImpl(JWTAlgorithm implAlg)
{
    static if (implAlg == JWTAlgorithm.HS256) enum signLen = 32;
    else static if (implAlg == JWTAlgorithm.HS384) enum signLen = 48;
    else static if (implAlg == JWTAlgorithm.HS512) enum signLen = 64;
    else static assert(0, "Unsupported algorithm for HMAC implementation");

    private
    {
        const(char)[] key;      // secret passed to the last successful loadKey (slice, not copied)
        gnutls_hmac_hd_t ctx;   // HMAC handle; null until loadKey succeeds
        ubyte[signLen] sigBuf;  // MAC produced by the last genSignature call
    }

    @disable this(this);

    ~this() @trusted
    {
        releaseCtx();
    }

    /**
     * Initializes the HMAC context with the given secret.
     *
     * A context created by a previous call is released first, so a failed
     * call leaves the handler without a usable key.
     *
     * Returns: true on success, false for an empty key or when GnuTLS fails
     *          to initialize the context.
     */
    bool loadKey(K)(K key) if (isToken!K)
    {
        // Reloading must not leak the handle created by an earlier call.
        releaseCtx();
        if (!key.length) return false;

        immutable ret = () @trusted
        {
            static if (implAlg == JWTAlgorithm.HS256) alias alg = gnutls_mac_algorithm_t.GNUTLS_MAC_SHA256;
            else static if (implAlg == JWTAlgorithm.HS384) alias alg = gnutls_mac_algorithm_t.GNUTLS_MAC_SHA384;
            else static if (implAlg == JWTAlgorithm.HS512) alias alg = gnutls_mac_algorithm_t.GNUTLS_MAC_SHA512;

            assert(signLen == gnutls_hmac_get_len(alg));
            return gnutls_hmac_init(&ctx, alg, key.ptr, key.length);
        }();

        if (ret != 0)
        {
            // A failed init does not give a usable handle: forget it so that
            // neither genSignature nor the destructor ever touch it.
            // NOTE(review): GnuTLS may leave a small allocation behind on this
            // (practically unreachable) path - confirm for the version in use.
            ctx = null;
            return false;
        }

        assert(ctx);
        this.key = cast(const(char)[])key;
        return true;
    }

    bool isValidAlg(JWTAlgorithm alg) { return implAlg == alg; }

    /**
     * Checks that `sign` is the HMAC of `value` under the loaded key.
     *
     * The comparison runs in constant time with respect to the signature
     * contents, so response timing does not reveal how many leading bytes of
     * a forged signature were correct.
     */
    bool isValid(V, S)(V value, S sign) if (isToken!V && isToken!S)
    {
        if (!genSignature(value)) return false;

        const provided = cast(const(ubyte)[])sign;
        if (provided.length != signLen) return false; // length is not secret

        // Accumulate all differences instead of stopping at the first mismatch.
        ubyte diff;
        foreach (i; 0 .. signLen) diff |= provided[i] ^ sigBuf[i];
        return diff == 0;
    }

    JWTAlgorithm signAlg() { return implAlg; }

    /**
     * Computes the HMAC of `value` and writes the raw MAC bytes to `sink`.
     *
     * Returns: the number of bytes written (`signLen`), or -1 on failure.
     */
    int sign(S, V)(auto ref S sink, auto ref V value) if (isToken!V)
    {
        import std.range : put;
        if (!genSignature(value)) return -1;
        put(sink, sigBuf[0..signLen]);
        return signLen;
    }

    // Frees the HMAC handle (if any) and forgets the key slice.
    private void releaseCtx() @trusted
    {
        if (ctx)
        {
            gnutls_hmac_deinit(ctx, null);
            ctx = null;
        }
        key = null;
    }

    // Computes the HMAC of `value` into sigBuf; false if no key is loaded,
    // `value` is empty or GnuTLS reports an error.
    private bool genSignature(V)(V value) @trusted
    {
        version (unittest) {} // no assert behavior is tested in unittest
        else assert(ctx, "Secret key not set");
        if (!ctx || !value.length) return false;

        if (gnutls_hmac(ctx, value.ptr, value.length) < 0)
        {
            // Flush whatever was absorbed so the next MAC starts from a clean
            // state (gnutls_hmac_output resets the context after writing).
            gnutls_hmac_output(ctx, sigBuf.ptr);
            return false;
        }

        gnutls_hmac_output(ctx, sigBuf.ptr);
        return true;
    }
}
103 
alias RS256Handler = PEMImpl!(JWTAlgorithm.RS256);
alias RS384Handler = PEMImpl!(JWTAlgorithm.RS384);
alias RS512Handler = PEMImpl!(JWTAlgorithm.RS512);
alias ES256Handler = PEMImpl!(JWTAlgorithm.ES256);
alias ES384Handler = PEMImpl!(JWTAlgorithm.ES384);
alias ES512Handler = PEMImpl!(JWTAlgorithm.ES512);

/**
 * Implementation of RS256, RS384, RS512, ES256, ES384 and ES512 signing
 * algorithms using PEM encoded keys.
 *
 * `loadKey` imports the public key used by `isValid`; `loadPKey` imports the
 * private key used by `sign`. The struct owns the GnuTLS key handles and is
 * therefore non-copyable.
 */
private struct PEMImpl(JWTAlgorithm implAlg)
{
    private
    {
        gnutls_x509_privkey_t x509key;  // parsed PEM private key
        gnutls_privkey_t privKey;       // abstract key wrapping x509key, used for signing
        gnutls_pubkey_t pubKey;         // public key used for verification

        import std.algorithm : among;
        static if (implAlg.among(JWTAlgorithm.ES256, JWTAlgorithm.ES384, JWTAlgorithm.ES512))
            enum pkAlg = gnutls_pk_algorithm_t.GNUTLS_PK_ECDSA;
        else static if (implAlg.among(JWTAlgorithm.RS256, JWTAlgorithm.RS384, JWTAlgorithm.RS512))
            enum pkAlg = gnutls_pk_algorithm_t.GNUTLS_PK_RSA;
        else static assert(0, "Unsupported algorithm for PEM implementation");

        static if (implAlg == JWTAlgorithm.ES256) alias alg = gnutls_sign_algorithm_t.GNUTLS_SIGN_ECDSA_SHA256;
        else static if (implAlg == JWTAlgorithm.ES384) alias alg = gnutls_sign_algorithm_t.GNUTLS_SIGN_ECDSA_SHA384;
        else static if (implAlg == JWTAlgorithm.ES512) alias alg = gnutls_sign_algorithm_t.GNUTLS_SIGN_ECDSA_SHA512;
        else static if (implAlg == JWTAlgorithm.RS256) alias alg = gnutls_sign_algorithm_t.GNUTLS_SIGN_RSA_SHA256;
        else static if (implAlg == JWTAlgorithm.RS384) alias alg = gnutls_sign_algorithm_t.GNUTLS_SIGN_RSA_SHA384;
        else static if (implAlg == JWTAlgorithm.RS512) alias alg = gnutls_sign_algorithm_t.GNUTLS_SIGN_RSA_SHA512;

        // Width in bytes of each of the R and S halves of a JWS ECDSA
        // signature (RFC 7518, section 3.4). Only defined for ES* algorithms.
        static if (implAlg == JWTAlgorithm.ES256) enum rsLen = 32;
        else static if (implAlg == JWTAlgorithm.ES384) enum rsLen = 48;
        else static if (implAlg == JWTAlgorithm.ES512) enum rsLen = 66;
    }

    @disable this(this);

    ~this() @trusted
    {
        releasePubKey();
        releasePrivKey();
    }

    /**
     * Imports a PEM encoded public key used to verify signatures.
     *
     * A public key loaded by a previous call is released first, so a failed
     * call leaves the handler without a public key.
     *
     * Returns: true on success, false for an empty or unparsable key.
     */
    bool loadKey(K)(K key) @trusted if (isToken!K)
    {
        releasePubKey();
        if (!key.length) return false;

        if (gnutls_pubkey_init(&pubKey)) return false;

        gnutls_datum_t cert_dat = gnutls_datum_t(cast(ubyte*)key.ptr, cast(uint)key.length);
        if (gnutls_pubkey_import(pubKey, &cert_dat, gnutls_x509_crt_fmt_t.GNUTLS_X509_FMT_PEM))
        {
            releasePubKey();
            return false;
        }

        return true;
    }

    /**
     * Imports a PEM encoded private key used to create signatures.
     *
     * A private key loaded by a previous call is released first. The key must
     * be of the public-key algorithm family this handler implements
     * (RSA for RS*, ECDSA for ES*).
     *
     * Returns: true on success, false for an empty, unparsable or
     *          mismatching key.
     */
    bool loadPKey(K)(K key) @trusted if (isToken!K)
    {
        releasePrivKey();
        if (!key.length) return false;

        if (gnutls_x509_privkey_init(&x509key)) return false;

        gnutls_datum_t keyData = gnutls_datum_t(cast(ubyte*)key.ptr, cast(uint)key.length);
        if (gnutls_x509_privkey_import(x509key, &keyData, gnutls_x509_crt_fmt_t.GNUTLS_X509_FMT_PEM)
            || gnutls_privkey_init(&privKey)
            || gnutls_privkey_import_x509(privKey, x509key, 0)
            || gnutls_privkey_get_pk_algorithm(privKey, null) != pkAlg)
        {
            releasePrivKey();
            return false;
        }

        return true;
    }

    bool isValidAlg(JWTAlgorithm alg) { return implAlg == alg; }

    /**
     * Verifies that `sign` is a valid signature of `value` under the loaded
     * public key. For ES* algorithms `sign` must be the JWS R || S form of
     * exactly `2 * rsLen` bytes.
     */
    bool isValid(V, S)(V value, S sign) @trusted if (isToken!V && isToken!S)
    {
        version (unittest) {} // no assert behavior is tested in unittest
        else assert(pubKey, "Public key not set");
        if (!value.length || !sign.length || !pubKey) return false;

        gnutls_datum_t data = gnutls_datum_t(cast(ubyte*)value.ptr, cast(uint)value.length);
        static if (pkAlg == gnutls_pk_algorithm_t.GNUTLS_PK_RSA)
        {
            // RSA signatures are passed to GnuTLS unchanged.
            gnutls_datum_t sig_dat = gnutls_datum_t(cast(ubyte*)sign.ptr, cast(uint)sign.length);
            if (gnutls_pubkey_verify_data2(pubKey, alg, 0, &data, &sig_dat))
                return false;
        }
        else
        {
            // JWS carries R || S as fixed-width integers; rebuild the encoded
            // signature GnuTLS expects from those two halves.
            if (sign.length != 2 * rsLen) return false;

            gnutls_datum_t r = gnutls_datum_t(cast(ubyte*)sign.ptr, rsLen);
            gnutls_datum_t s = gnutls_datum_t((cast(ubyte*)sign.ptr) + rsLen, rsLen);

            gnutls_datum_t sig_dat;
            scope (exit)
            {
                if (sig_dat.data) gnutls_free(sig_dat.data);
            }

            if (gnutls_encode_rs_value(&sig_dat, &r, &s)) return false;
            if (gnutls_pubkey_verify_data2(pubKey, alg, 0, &data, &sig_dat)) return false;
        }
        return true;
    }

    JWTAlgorithm signAlg() { return implAlg; }

    /**
     * Signs `value` with the loaded private key and writes the signature to
     * `sink`. ES* signatures are emitted in the JWS R || S form.
     *
     * Returns: the number of bytes written, or -1 on failure.
     */
    int sign(S, V)(auto ref S sink, auto ref V value) @trusted
    {
        import std.range : put;

        version (unittest) {} // no assert behavior is tested in unittest
        else assert(privKey, "Private key not set");
        if (!value.length || !privKey) return -1;

        gnutls_datum_t body_dat = gnutls_datum_t(cast(ubyte*)value.ptr, cast(uint)value.length);
        gnutls_datum_t sig_dat;

        if (gnutls_privkey_sign_data2(privKey, alg, 0, &body_dat, &sig_dat)) return -1;
        scope (exit) gnutls_free(sig_dat.data);

        static if (pkAlg == gnutls_pk_algorithm_t.GNUTLS_PK_RSA)
        {
            put(sink, sig_dat.data[0 .. sig_dat.size]);
            return cast(int)sig_dat.size;
        }
        else
        {
            gnutls_datum_t r, s;
            if (gnutls_decode_rs_value(&sig_dat, &r, &s)) return -1;
            scope (exit)
            {
                gnutls_free(r.data);
                gnutls_free(s.data);
            }

            // Copies the big-endian integer `src` right-aligned into `dst`:
            // surplus leading bytes of `src` are dropped, and a shorter `src`
            // leaves the (pre-zeroed) leading bytes of `dst` as padding.
            static void copyFixed(ubyte[] dst, ref const gnutls_datum_t src)
            {
                immutable size_t n = src.size < dst.length ? src.size : dst.length;
                memcpy(dst.ptr + (dst.length - n), src.data + (src.size - n), n);
            }

            ubyte[2 * rsLen] buf; // zero-initialized, which supplies the left padding
            copyFixed(buf[0 .. rsLen], r);
            copyFixed(buf[rsLen .. $], s);

            put(sink, buf[]);
            return cast(int)buf.length;
        }
    }

    // Frees the public key handle, if any.
    private void releasePubKey() @trusted
    {
        if (pubKey)
        {
            gnutls_pubkey_deinit(pubKey);
            pubKey = null;
        }
    }

    // Frees the private key handles, if any. privKey is imported from x509key
    // with flags 0 (no auto-release), so both are freed here, wrapper first.
    private void releasePrivKey() @trusted
    {
        if (privKey)
        {
            gnutls_privkey_deinit(privKey);
            privKey = null;
        }
        if (x509key)
        {
            gnutls_x509_privkey_deinit(x509key);
            x509key = null;
        }
    }
}
303 
304 version (unittest) import jwtlited.tests;
305 
@("GnuTLS tests")
@safe unittest
{
    import std.algorithm : canFind, filter;

    alias Alg = JWTAlgorithm;

    // Bytes the GC has allocated for the current thread so far; the druntime
    // API for this changed with frontend version 2.094.
    static auto gcBytes() @trusted
    {
        import core.memory : GC;
        static if (__VERSION__ >= 2094) return GC.allocatedInCurrentThread();
        else return GC.stats().allocatedInCurrentThread;
    }

    // Loads the test case's key(s) into a fresh handler of type H, checks the
    // load result against the expectation, then runs the shared test driver.
    static void runCase(H)(ref immutable TestCase tc)
    {
        H handler;
        immutable bool keyOk = !!(tc.valid & Valid.key);

        enum isHmac = is(H == HS256Handler) || is(H == HS384Handler) || is(H == HS512Handler);
        static if (isHmac)
        {
            // HMAC handlers only take a shared secret.
            assert(handler.loadKey(tc.key) == keyOk);
        }
        else
        {
            // Asymmetric handlers: public key to decode, private key to encode.
            if (tc.test & Test.decode) assert(handler.loadKey(tc.key) == keyOk);
            if (tc.test & Test.encode) assert(handler.loadPKey(tc.pkey) == keyOk);
        }

        evalTest(handler, tc);
    }

    immutable before = gcBytes();

    static immutable covered = [
        Alg.HS256, Alg.HS384, Alg.HS512,
        Alg.RS256, Alg.RS384, Alg.RS512,
        Alg.ES256, Alg.ES384, Alg.ES512
    ];

    foreach (tc; testCases.filter!(c => covered.canFind(c.alg)))
    {
        final switch (tc.alg)
        {
            case Alg.none: assert(0);

            case Alg.HS256: runCase!HS256Handler(tc); break;
            case Alg.HS384: runCase!HS384Handler(tc); break;
            case Alg.HS512: runCase!HS512Handler(tc); break;

            case Alg.RS256: runCase!RS256Handler(tc); break;
            case Alg.RS384: runCase!RS384Handler(tc); break;
            case Alg.RS512: runCase!RS512Handler(tc); break;

            case Alg.ES256: runCase!ES256Handler(tc); break;
            case Alg.ES384: runCase!ES384Handler(tc); break;
            case Alg.ES512: runCase!ES512Handler(tc); break;
        }
    }

    // The handlers must not touch the GC at all.
    assert(gcBytes() == before);
}
366