1 module jwtlited;
2 
3 import std.algorithm;
4 import std.base64;
5 import std.string;
6 import std.traits;
7 import bc.string.string;
8 
/**
 * Supported algorithms.
 *
 * Member names match the JWA `alg` identifiers and are used verbatim (via `%s` formatting) as the
 * `"alg"` header values. The member order is significant: the module's lookup tables are indexed
 * by member value, starting with `none` at 0.
 */
enum JWTAlgorithm
{
    none = 0,

    // HS* family
    HS256, HS384, HS512,

    // RS* family
    RS256, RS384, RS512,

    // ES* family
    ES256, ES384, ES512
}
23 
/**
 * Handler for unsigned tokens.
 *
 * As a validator it accepts only the `none` algorithm together with an empty signature part.
 * As a signer it declares the `none` algorithm and produces no signature bytes.
 */
struct NoneHandler
{
    @safe pure nothrow @nogc
    {
        /// Only `JWTAlgorithm.none` is accepted.
        bool isValidAlg(JWTAlgorithm alg)
        {
            return JWTAlgorithm.none == alg;
        }

        /// Valid only when the decoded signature is empty; the signed content is not inspected.
        bool isValid(V, S)(V value, S sign) if (isToken!V && isToken!S)
        {
            return !sign.length;
        }

        /// Algorithm to put into the header of produced tokens.
        JWTAlgorithm signAlg()
        {
            return JWTAlgorithm.none;
        }

        /// Writes nothing into `sink` and reports 0 signature bytes.
        int sign(S, V)(auto ref S sink, auto ref V value)
        {
            return 0;
        }
    }
}
49 
///
@safe unittest
{
    import jwtlited;
    import std.stdio;

    enum payload = `{"foo":42}`;
    NoneHandler handler;

    // encode into a fixed size buffer
    char[512] buf;
    immutable written = handler.encode(buf[], payload);
    assert(written > 0);
    auto token = buf[0 .. written];
    writeln("NONE: ", token);

    // the produced token validates and decodes back to the original payload
    assert(handler.validate(token));
    char[32] decoded;
    assert(handler.decode(token, decoded[]));
    assert(decoded[0 .. payload.length] == payload);
}
68 
unittest
{
    // NoneHandler works both as a validator and as a signer
    static assert(isValidator!NoneHandler && isSigner!NoneHandler);
}
74 
/**
 * Validator that accepts every algorithm it is asked about and ignores the signature completely.
 * It can be used to decode a token without validating its signature.
 */
struct AnyAlgValidator
{
    @safe pure nothrow @nogc
    {
        /// Any algorithm is accepted.
        bool isValidAlg(JWTAlgorithm alg)
        {
            return true;
        }

        /// Any signature is accepted; neither the signed content nor the signature is inspected.
        bool isValid(V, S)(V value, S sign) if (isToken!V && isToken!S)
        {
            return true;
        }
    }
}
90 
unittest
{
    // AnyAlgValidator only validates, it has no signing interface
    static assert(isValidator!AnyAlgValidator);
    static assert(!isSigner!AnyAlgValidator, "AnyAlgValidator must not be usable as a signer");
}
96 
// Lookup tables filled by the module constructor and indexed by `JWTAlgorithm` value:
// base64url encoded default headers and textual algorithm names.
private immutable string[] base64HeaderStrings;
private immutable string[] algStrings;
102 
/// Builds the per-algorithm lookup tables (indexed by `JWTAlgorithm` value).
shared static this()
{
    import std.algorithm : map;
    import std.array : array;
    import std.base64 : Base64URLNoPadding;
    import std.format : format;
    import std.string : representation;
    import std.traits : EnumMembers;

    // base64url encoded minimal headers `{"alg":"<name>"}` for each JWT algorithm;
    // `representation` views the formatted string as immutable bytes instead of casting away immutability
    base64HeaderStrings = [EnumMembers!JWTAlgorithm]
        .map!(a => Base64URLNoPadding.encode(format!`{"alg":"%s"}`(a).representation))
        .array;

    // textual algorithm names as they appear in the "alg" header field
    algStrings = [EnumMembers!JWTAlgorithm].map!(a => format!"%s"(a)).array;
}
118 
119 // TODO: use SSE4.2 optimised token parser to also check for invalid characters in token while advancing between '.'
120 
/**
 * Decodes and validates the JWT token.
 *
 * The header is always base64url decoded and its "alg" value is checked with the validator.
 * The payload is decoded only when `payloadSink` is provided, otherwise it's just skipped.
 * The signature part is base64url decoded and passed to the provided validator implementation to check.
 *
 * JSON header and payload validation is out of scope of this function. It just checks the basic structure of JWT.
 *
 * Params:
 *   validator   = algorithm and signature validator (see `isValidator`)
 *   token       = compact serialized token (`header.payload.signature`)
 *   headSink    = output range receiving the decoded JSON header, or `null` to skip it
 *   payloadSink = output range or array receiving the decoded JSON payload, or `null` to skip it.
 *                 An array too small to hold the decoded payload makes the function return `false`.
 *                 The number of decoded characters is not reported.
 *
 * Returns: `true` when the token structure is valid and the validator accepts both its algorithm and signature.
 *
 * Note: Only compact encoded JWS format is supported.
 */
bool decode(V, T, HS, PS)(auto ref V validator, T token, auto ref HS headSink, auto ref PS payloadSink)
    if (isToken!T && isValidator!V)
{
    import std.range : put;

    // get header part
    immutable hlen = (cast(const(ubyte)[])token).countUntil('.');
    if (hlen <= 0) return false;

    // a single base64 character decodes to nothing, so there can't be any "alg" in it
    // (and indexing the empty buffer below would be out of bounds)
    immutable hdrLen = Base64URLNoPadding.decodeLength(hlen);
    if (hdrLen == 0) return false;

    // decode header to check used algorithm (function-static buffer reused between calls)
    static String hdrBuf;
    hdrBuf.clear();

    // TODO: should pass, see similar: https://issues.dlang.org/show_bug.cgi?id=18168
    // problem only with OutputRange
    hdrBuf.reserve(hdrLen); // relies on reserve() also growing hdrBuf.length, used for the slice below
    auto pc = () @trusted { return &hdrBuf[0]; }(); // workaround as Base64.decode doesn't accept char[]
    try () @trusted { Base64URLNoPadding.decode(token[0..hlen], (cast(ubyte*)pc)[0..hdrBuf.length]); }();
    catch (Exception) return false;

    JWTAlgorithm alg;
    immutable algret = parseHeaderAlgorithm(hdrBuf[], alg);
    if (algret != 0) return false;
    if (!validator.isValidAlg(alg)) return false;

    // find end of the payload
    immutable plen = (cast(const(ubyte)[])token[hlen+1..$]).countUntil('.');
    if (plen <= 0) return false;

    // check that sign is the last part of the token
    if ((cast(const(ubyte)[])token[hlen+plen+2..$]).countUntil('.') >= 0) return false; // JWS has only 3 parts

    // copy header if requested
    static if (!is(HS == typeof(null)))
        put(headSink, hdrBuf[]);

    // decode payload if requested
    static if (!is(PS == typeof(null)))
    {
        // std.base64 only asserts (doesn't throw) on an insufficient array buffer, so reject it here
        static if (isArray!PS)
        {
            if (payloadSink.length < Base64URLNoPadding.decodeLength(plen)) return false;
        }

        static if (isArray!PS && is(ForeachType!PS == char))
        {
            auto ps = () @trusted
            {
                auto p = payloadSink.ptr;
                return (cast(ubyte*)p)[0..payloadSink.length];
            }();
        }
        else alias ps = payloadSink;

        try () @trusted { Base64URLNoPadding.decode(token[hlen+1 .. hlen+plen+1], ps); }(); // TODO: see same problem above
        catch (Exception) return false;
    }

    // validate signature with the provided validator
    ubyte[512] sigBuf; // RSA 4096 has a 512B signature, we don't expect more
    auto sigB64 = token[hlen+plen+2..$];
    ubyte[] sig;
    if (sigB64.length)
    {
        if (Base64URLNoPadding.decodeLength(sigB64.length) > sigBuf.length) return false;
        try sig = Base64URLNoPadding.decode(sigB64, sigBuf[]);
        catch (Exception) return false;
    }
    return validator.isValid(token[0..hlen+plen+1], sig);
}
196 
/// ditto
bool decode(V, T, S)(auto ref V validator, T token, auto ref S payloadSink)
    if (isToken!T && isValidator!V)
{
    // header is still decoded and checked, it just isn't copied anywhere
    enum skipHeader = null;
    return decode(validator, token, skipHeader, payloadSink);
}
203 
/**
 * Decodes the token payload without validating its signature.
 *
 * The header must still name an algorithm known to `JWTAlgorithm` and the basic token structure
 * is checked, but any such algorithm and any signature are accepted.
 */
bool decodePayload(T, S)(T token, auto ref S payloadSink)
    if (isToken!T)
{
    AnyAlgValidator acceptAll;
    return decode(acceptAll, token, null, payloadSink);
}
213 
/**
 * Validates token format and signature with a provided validator.
 * It doesn't base64 decode the payload and doesn't copy the header anywhere.
 *
 * Returns: `true` when the token structure is valid and the validator accepts its algorithm and signature.
 */
bool validate(V, T)(auto ref V validator, T token)
    if (isToken!T && isValidator!V) // same constraint as `decode`, so mismatched types fail at the call site
{
    return decode(validator, token, null, null);
}
222 
/**
 * Encodes a token using the provided signer algorithm and an already prepared payload.
 *
 * If a header is also provided, its `alg` field is checked against the signer and added when missing.
 *
 * Both the header and the payload are expected to be valid serialized JSON objects.
 *
 * Returns: -1 on error, otherwise the number of characters written to the output.
 */
int encode(S, O, P)(auto ref S signer, auto ref O output, P payload)
    if (isSigner!S && isToken!P)
{
    // no header given: use the precomputed base64url header of the signer's algorithm
    auto defaultHeader = base64HeaderStrings[signer.signAlg];
    return encodeImpl!false(signer, output, defaultHeader, payload);
}
237 
/// ditto
int encode(S, O, H, P)(auto ref S signer, auto ref O output, H header, P payload)
    if (isSigner!S && isToken!H && isToken!P)
{
    // an empty header behaves exactly like the header-less overload
    if (!header.length)
        return encodeImpl!false(signer, output, base64HeaderStrings[signer.signAlg], payload);
    return encodeImpl!true(signer, output, header, payload);
}
245 
/**
 * Shared implementation of `encode`. It assembles `header.payload.signature` in a function-static
 * buffer and copies the result to `output`.
 *
 * Params:
 *   checkHeader = `true`: `header` is a JSON object that is checked for `alg`, completed if needed
 *                 and base64url encoded. `false`: `header` is already base64url encoded and used verbatim.
 *   signer      = signer implementation (see `isSigner`)
 *   output      = output range or array receiving the token
 *   header      = header as described for `checkHeader`
 *   payload     = JSON payload to be base64url encoded; must not be empty
 *
 * Returns: -1 on error, otherwise number of characters written to the output.
 */
private int encodeImpl(bool checkHeader, S, O, H, P)(auto ref S signer, auto ref O output, H header, P payload)
    if (isSigner!S && isToken!H && isToken!P)
{
    import std.ascii : isWhite;
    import std.range : put;

    // a token with an empty payload part can't be decoded back
    if (payload.length == 0) return -1;

    static String tmp;
    tmp.clear();

    static if (checkHeader)
    {
        if (header.length < 2 || header[0] != '{' || header[$-1] != '}') return -1;
        JWTAlgorithm alg;
        immutable algret = parseHeaderAlgorithm(header, alg);
        if (algret < -1) return -1; // malformed or unsupported "alg" value
        if (algret == 0)
        {
            // "alg" already present: it must match the signer, then the header is encoded as is
            if (alg != signer.signAlg) return -1;
            appendBase64URL(tmp, header);
        }
        else
        {
            // "alg" missing: prepend it to the provided header fields
            String hdrtmp;
            hdrtmp ~= `{"alg":"`;
            hdrtmp ~= algStrings[signer.signAlg];
            hdrtmp ~= '"';

            bool emptyObject = true;
            foreach (c; header[1 .. $-1])
            {
                if (!isWhite(c)) { emptyObject = false; break; }
            }

            if (emptyObject) hdrtmp ~= '}'; // avoid a dangling comma for `{}`
            else
            {
                hdrtmp ~= ',';
                hdrtmp ~= header[1..$];
            }
            appendBase64URL(tmp, hdrtmp[]);
        }
    }
    else tmp ~= header; // precomputed, already base64url encoded

    tmp ~= '.';
    appendBase64URL(tmp, payload);
    tmp ~= '.';

    // sign `header.payload` (without the trailing '.')
    ubyte[512] sigtmp;
    immutable len = signer.sign(sigtmp[], tmp[0..$-1]);
    if (len < 0 || len > cast(int) sigtmp.length) return -1;
    if (len) appendBase64URL(tmp, sigtmp[0..len]);

    // arrays can't grow, so a too small one is reported as an error instead of overflowing
    static if (isArray!O)
    {
        if (output.length < tmp.length) return -1;
    }

    put(output, tmp[]);
    return cast(int) tmp.length;
}

/// Appends base64url (no padding) encoding of `data` to the end of `buf`.
private void appendBase64URL(R)(ref String buf, R data)
{
    immutable start = buf.length;
    // relies on bc.string's reserve() growing buf.length by the reserved amount; the tail is filled in place
    buf.reserve(Base64URLNoPadding.encodeLength(data.length));
    Base64URLNoPadding.encode(data, buf[start .. $]);
}
300 
unittest
{
    import std.stdio;

    NoneHandler none;
    String buf;

    // header without "alg": the signer's algorithm gets added to it
    immutable ret = encode(none, buf, `{"foo":"bar"}`, `{"baz":42}`);
    assert(ret > 0); // plain `assert(ret)` would also accept the -1 error code
    assert(ret == buf.length);
    writeln(buf[]);

    // produced token must be accepted back by the same handler
    assert(none.validate(buf[]));
    char[32] pay;
    assert(none.decode(buf[], pay[]));
    assert(pay[0 .. 10] == `{"baz":42}`);
}
311 
/// Whether `T` can hold token text: an array whose unqualified element type implicitly converts to `char`.
template isToken(T)
{
    import std.traits : ForeachType, isArray, Unqual;

    static if (isArray!T)
        enum isToken = is(Unqual!(ForeachType!T) : char);
    else
        enum isToken = false;
}
317 
unittest
{
    // character arrays as well as raw byte arrays qualify as tokens
    enum stringIsToken = isToken!string;
    enum bytesAreToken = isToken!(ubyte[]);
    static assert(stringIsToken && bytesAreToken);
}
323 
/// Whether `V` has the validator members `isValidAlg` and `isValid` (only their presence is checked).
enum isValidator(V) = __traits(hasMember, V, "isValid") && __traits(hasMember, V, "isValidAlg");
328 
/// Whether `S` has the signer members `signAlg` and `sign` (only their presence is checked).
enum isSigner(S) = __traits(hasMember, S, "sign") && __traits(hasMember, S, "signAlg");
333 
/**
 * Poor man's lookup of the `"alg"` field in a JSON header.
 *
 * Works on code units (bytes), so non-ASCII UTF-8 content elsewhere in the header doesn't shift
 * the computed positions. Nesting is not tracked, so an `"alg"` key of a nested object would be
 * found as well.
 *
 * Params:
 *   hdr = JSON header text
 *   alg = set to the parsed algorithm when 0 is returned
 *
 * Returns: 0 ok, -1 missing, -2 error, -3 unknown or unsupported alg
 */
private int parseHeaderAlgorithm(H)(H hdr, out JWTAlgorithm alg)
    if (isToken!H)
{
    import std.ascii : isAlphaNum, isWhite;

    // byte view: countUntil on char arrays would count code points, not indices
    auto bytes = cast(const(ubyte)[]) hdr;
    immutable key = cast(immutable(ubyte)[]) `"alg"`;

    // find `"alg"` used as a key, i.e. followed by optional whitespace and ':'
    size_t idx;
    size_t from;
    while (true)
    {
        immutable rel = bytes[from .. $].countUntil(key);
        if (rel < 0) return -1; // alg value is REQUIRED
        idx = from + rel + key.length;
        while (idx < bytes.length && isWhite(bytes[idx])) idx++;
        if (idx < bytes.length && bytes[idx] == ':') break;
        from = from + rel + 1; // it was a string value, keep searching
    }
    idx++; // skip ':'

    while (idx < bytes.length && isWhite(bytes[idx])) idx++; // skip possible whitespaces
    if (idx == bytes.length || bytes[idx] != '"') return -2;
    immutable algStart = ++idx;
    // NOTE: expected only alphanum characters for supported JWT algorithms, but needs to be changed to support JWE
    while (idx < bytes.length && isAlphaNum(bytes[idx])) idx++;
    if (idx == bytes.length || bytes[idx] != '"') return -2;
    auto algVal = bytes[algStart .. idx];

    // get used algorithm (algStrings is indexed by JWTAlgorithm value)
    foreach (i, name; algStrings)
    {
        if (cast(const(ubyte)[]) name == algVal)
        {
            alg = cast(JWTAlgorithm) i;
            return 0;
        }
    }
    return -3;
}