module jwtlited;

import std.algorithm;
import std.base64;
import std.string;
import std.traits;
import bc.string.string;

/// Supported algorithms
enum JWTAlgorithm
{
    none,
    HS256,
    HS384,
    HS512,
    RS256,
    RS384,
    RS512,
    ES256,
    ES384,
    ES512
}

/**
 * Structure that can be used to handle tokens without signatures.
 * Requires that the token has `"alg": "none"` in the header and no sign part.
 */
struct NoneHandler
{
    @safe pure nothrow @nogc:

    /// Only the `none` algorithm is accepted.
    bool isValidAlg(JWTAlgorithm alg)
    {
        return alg == JWTAlgorithm.none;
    }

    /// The token is valid only when its sign part is empty.
    bool isValid(V, S)(V value, S sign) if (isToken!V && isToken!S)
    {
        return sign.length == 0;
    }

    /// Algorithm announced in the header of produced tokens.
    JWTAlgorithm signAlg() { return JWTAlgorithm.none; }

    /// Produces no signature: always reports 0 written bytes.
    int sign(S, V)(auto ref S sink, auto ref V value)
    {
        return 0;
    }
}

///
@safe unittest
{
    import jwtlited;
    import std.stdio;

    NoneHandler handler;
    char[512] token;
    enum payload = `{"foo":42}`;
    immutable len = handler.encode(token[], payload);
    assert(len > 0);
    writeln("NONE: ", token[0..len]);

    assert(handler.validate(token[0..len]));
    char[32] pay;
    assert(handler.decode(token[0..len], pay[]));
    assert(pay[0..payload.length] == payload);
}

unittest
{
    static assert(isValidator!NoneHandler);
    static assert(isSigner!NoneHandler);
}

/**
 * Validator that accepts any JWT algorithm and ignores its signature entirely.
 * Can be used to decode a token without validating its signature.
 */
struct AnyAlgValidator
{
    @safe pure nothrow @nogc:

    /// Every algorithm is accepted.
    bool isValidAlg(JWTAlgorithm alg) { return true; }

    /// Every signature (including an empty one) is accepted.
    bool isValid(V, S)(V value, S sign) if (isToken!V && isToken!S)
    {
        return true;
    }
}

unittest
{
    static assert(isValidator!AnyAlgValidator);
    static assert(!isSigner!AnyAlgValidator);
}

private
{
    /// Base64url (no padding) encoded `{"alg":"<name>"}` headers, indexed by `JWTAlgorithm` value.
    immutable string[] base64HeaderStrings;
    /// `JWTAlgorithm` member names, indexed by `JWTAlgorithm` value.
    immutable string[] algStrings;
}

shared static this()
{
    import std.algorithm : map;
    import std.array : array;
    import std.base64 : Base64URLNoPadding;
    import std.format : format;
    import std.traits : EnumMembers;

    // precompute the default base64url encoded header for every JWT algorithm
    base64HeaderStrings = [EnumMembers!JWTAlgorithm]
        .map!(a => Base64URLNoPadding.encode(cast(ubyte[])(format!`{"alg":"%s"}`(a))))
        .array;

    algStrings = [EnumMembers!JWTAlgorithm].map!(a => format!"%s"(a)).array;
}

// TODO: use SSE4.2 optimised token parser to also check for invalid characters in token while advancing between '.'

/**
 * Decodes and validates the JWT token.
 *
 * It always base64url decodes the header and checks the "alg" value in it.
 * The decoded header is copied to `headSink` and the payload is decoded to `payloadSink` only when
 * they are provided (pass `null` to skip them); otherwise the payload is just skipped.
 * The sign part is base64url decoded and passed to the provided validator implementation to check.
 *
 * When an array is used as a sink and it is too small to hold the decoded data, `false` is returned.
 * Sinks are written before the signature is validated, so they may already contain data when
 * `false` is returned.
 *
 * JSON header and payload validation is out of scope of this function. It just checks the basic structure of JWT.
 * Note: Only compact encoded JWS format is supported.
 *
 * Params:
 *   validator = checks the header algorithm (`isValidAlg`) and the signature (`isValid`)
 *   token = compact serialized JWS token
 *   headSink = output range (or array) for the decoded JSON header, or `null`
 *   payloadSink = output range (or array) for the decoded payload, or `null`
 * Returns: `true` when the token is well formed and accepted by the validator, `false` otherwise.
 */
bool decode(V, T, HS, PS)(auto ref V validator, T token, auto ref HS headSink, auto ref PS payloadSink)
    if (isToken!T && isValidator!V)
{
    import std.range : put;

    // get header part (scanned as raw bytes so indices are code unit offsets)
    immutable hlen = (cast(const(ubyte)[])token).countUntil('.');
    if (hlen <= 0) return false;

    // a header this short can't decode to any byte (and would leave hdrBuf empty below)
    immutable hdrLen = Base64URLNoPadding.decodeLength(hlen);
    if (hdrLen == 0) return false;

    // decode header to check used algorithm
    static String hdrBuf;
    hdrBuf.clear();

    // TODO: should pass, see similar: https://issues.dlang.org/show_bug.cgi?id=18168
    // problem only with OutputRange
    // NOTE(review): relies on bc.string `reserve` also extending hdrBuf.length (used for the slice below)
    hdrBuf.reserve(hdrLen);
    auto pc = () @trusted { return &hdrBuf[0]; }(); // workaround as Base64.decode doesn't accept char[]
    try () @trusted { Base64URLNoPadding.decode(token[0 .. hlen], (cast(ubyte*)pc)[0 .. hdrBuf.length]); }();
    catch (Exception) return false;

    JWTAlgorithm alg;
    if (parseHeaderAlgorithm(hdrBuf[], alg) != 0) return false;
    if (!validator.isValidAlg(alg)) return false;

    // find end of the payload
    immutable plen = (cast(const(ubyte)[])token[hlen + 1 .. $]).countUntil('.');
    if (plen <= 0) return false;

    // check that sign is the last part of the token
    immutable sigStart = hlen + plen + 2;
    if ((cast(const(ubyte)[])token[sigStart .. $]).countUntil('.') >= 0) return false; // JWS has only 3 parts

    // arrays can't grow: refuse instead of overrunning them
    // (Base64.decode requires a buffer of at least decodeLength bytes)
    static if (isArray!HS)
    {
        if (headSink.length < hdrBuf.length) return false;
    }
    static if (isArray!PS)
    {
        if (payloadSink.length < Base64URLNoPadding.decodeLength(plen)) return false;
    }

    // copy header if requested
    static if (!is(HS == typeof(null)))
        put(headSink, hdrBuf[]);

    // decode payload if requested
    static if (!is(PS == typeof(null)))
    {
        static if (isArray!PS && is(ForeachType!PS == char))
        {
            auto ps = () @trusted
            {
                auto p = payloadSink.ptr;
                return (cast(ubyte*)p)[0 .. payloadSink.length];
            }();
        }
        else alias ps = payloadSink;

        try () @trusted { Base64URLNoPadding.decode(token[hlen + 1 .. hlen + plen + 1], ps); }(); // TODO: see same problem above
        catch (Exception) return false;
    }

    // validate signature with the provided validator
    ubyte[512] sigBuf; // RSA 4096 has a 512B signature, we don't expect more
    auto sigB64 = token[sigStart .. $];
    ubyte[] sig;
    if (sigB64.length)
    {
        if (Base64URLNoPadding.decodeLength(sigB64.length) > sigBuf.length) return false;
        try sig = Base64URLNoPadding.decode(sigB64, sigBuf[]);
        catch (Exception) return false;
    }
    // signed content is "<header>.<payload>"
    return validator.isValid(token[0 .. hlen + plen + 1], sig);
}

/// ditto
bool decode(V, T, S)(auto ref V validator, T token, auto ref S payloadSink)
    if (isToken!T && isValidator!V)
{
    return decode(validator, token, null, payloadSink);
}

unittest
{
    NoneHandler none;
    char[512] token;
    immutable len = encode(none, token[], `{"foo":42}`);
    assert(len > 0);

    // payload sink too small for the decoded payload
    char[2] small;
    assert(!decode(none, token[0 .. len], small[]));

    // header sink too small for the decoded header
    char[2] smallHdr;
    char[32] pay;
    assert(!decode(none, token[0 .. len], smallHdr[], pay[]));

    // header part too short to decode to anything
    assert(!validate(none, "a.b."));

    // JWS has only 3 parts
    assert(!validate(none, token[0 .. len] ~ "."));
}

/**
 * Decodes token payload without signature validation.
 * Only token header is checked for any "alg" and basic token structure is validated.
 */
bool decodePayload(T, S)(T token, auto ref S payloadSink)
    if (isToken!T)
{
    return decode(AnyAlgValidator.init, token, null, payloadSink);
}

/**
 * Validates token format and signature with a provided validator.
 * It doesn't base64 decode the payload.
 */
bool validate(V, T)(auto ref V validator, T token)
    if (isToken!T && isValidator!V)
{
    return decode(validator, token, null, null);
}

/**
 * Encodes token using provided Signer algorithm and already prepared payload.
 *
 * If header is also provided it's checked for correct `alg` header field and added if not set.
 *
 * Both header and payload are expected to be a valid json object serialized string.
 *
 * Errors are: a provided header that is not a `{...}` object, a malformed, unknown or mismatching
 * `alg` in it, a failing signer, or an array output that is too small for the token.
 *
 * Returns: -1 on error, otherwise number of characters written to the output.
 */
int encode(S, O, P)(auto ref S signer, auto ref O output, P payload)
    if (isSigner!S && isToken!P)
{
    return encodeImpl!false(signer, output, base64HeaderStrings[signer.signAlg], payload);
}

/// ditto
int encode(S, O, H, P)(auto ref S signer, auto ref O output, H header, P payload)
    if (isSigner!S && isToken!H && isToken!P)
{
    if (header.length) return encodeImpl!true(signer, output, header, payload);
    return encodeImpl!false(signer, output, base64HeaderStrings[signer.signAlg], payload);
}

// Appends the base64url (no padding) encoding of `data` to `buf`; does nothing for empty `data`.
private void appendBase64URL(D)(ref String buf, D data)
{
    immutable elen = Base64URLNoPadding.encodeLength(data.length);
    if (elen == 0) return; // nothing to write, and &buf[start] would be out of bounds
    immutable start = buf.length;
    // NOTE(review): relies on bc.string `reserve` also extending buf.length, as the rest of this module does
    buf.reserve(elen);
    auto dst = () @trusted { return (cast(ubyte*)&buf[start])[0 .. elen]; }();
    Base64URLNoPadding.encode(data, dst);
}

// Builds "<b64 header>.<b64 payload>.<b64 sign>" into a per-thread buffer and puts it to `output`.
// With `checkHeader` the header is a caller provided JSON object, otherwise it's already base64url encoded.
private int encodeImpl(bool checkHeader, S, O, H, P)(auto ref S signer, auto ref O output, H header, P payload)
    if (isSigner!S && isToken!H && isToken!P)
{
    import std.range : put;

    static String tmp;
    tmp.clear();

    static if (checkHeader)
    {
        if (header.length < 2 || header[0] != '{' || header[$-1] != '}') return -1;
        JWTAlgorithm alg;
        immutable algret = parseHeaderAlgorithm(header, alg);
        if (algret < -1) return -1;
        if (algret == 0)
        {
            // header already announces an algorithm - it must be the signer's one
            if (alg != signer.signAlg) return -1;
            appendBase64URL(tmp, header);
        }
        else
        {
            import std.ascii : isWhite;

            // "alg" is missing: inject it as the first member of the object
            bool emptyObject = true;
            foreach (c; header[1 .. $-1])
            {
                if (!isWhite(c))
                {
                    emptyObject = false;
                    break;
                }
            }

            String hdrtmp;
            hdrtmp ~= `{"alg":"`;
            hdrtmp ~= algStrings[signer.signAlg];
            hdrtmp ~= emptyObject ? `"` : `",`; // no dangling comma for `{}`
            hdrtmp ~= header[1 .. $];
            appendBase64URL(tmp, hdrtmp[]);
        }
    }
    else tmp ~= header; // precomputed, already encoded

    tmp ~= '.';
    appendBase64URL(tmp, payload);
    tmp ~= '.';

    // sign "<header>.<payload>" (without the trailing dot)
    ubyte[512] sigtmp;
    auto len = signer.sign(sigtmp[], tmp[0 .. $-1]);
    if (len < 0 || cast(size_t)len > sigtmp.length) return -1;
    if (len) appendBase64URL(tmp, sigtmp[0 .. len]);

    // arrays can't grow: report an error instead of overrunning the output
    static if (isArray!O)
    {
        if (output.length < tmp.length) return -1;
    }

    put(output, tmp[]);
    return cast(int)tmp.length;
}

unittest
{
    NoneHandler none;
    String buf;
    immutable ret = encode(none, buf, `{"foo":"bar"}`, `{"baz":42}`);
    assert(ret);

    import std.stdio;
    writeln(buf[]);
}

unittest
{
    NoneHandler none;
    char[512] token;
    char[64] hdr;
    char[16] pay;
    enum payload = `{"baz":42}`;

    // header that already has the matching "alg" must be base64url encoded too
    enum withAlg = `{"alg":"none","typ":"JWT"}`;
    auto len = encode(none, token[], withAlg, payload);
    assert(len > 0);
    assert(decode(none, token[0 .. len], hdr[], pay[]));
    assert(hdr[0 .. withAlg.length] == withAlg);
    assert(pay[0 .. payload.length] == payload);

    // "alg" is injected into an empty header object without producing invalid JSON
    enum injected = `{"alg":"none"}`;
    len = encode(none, token[], `{}`, payload);
    assert(len > 0);
    assert(decode(none, token[0 .. len], hdr[], pay[]));
    assert(hdr[0 .. injected.length] == injected);

    // mismatching "alg" in the provided header is an error
    assert(encode(none, token[], `{"alg":"HS256"}`, payload) == -1);

    // empty payload must not index past the buffer
    len = encode(none, token[], "");
    assert(len > 0);
    assert(token[len - 2 .. len] == "..");

    // too small output buffer is reported as an error
    char[4] tiny;
    assert(encode(none, tiny[], payload) == -1);
}

/// True when `T` is an array whose (unqualified) element type implicitly converts to `char`.
template isToken(T)
{
    import std.traits : isArray, Unqual, ForeachType;
    enum isToken = isArray!T && is(Unqual!(ForeachType!T) : char);
}

unittest
{
    static assert(isToken!string);
    static assert(isToken!(ubyte[]));
}

/// True when `V` has `isValidAlg` and `isValid` members, as required by `decode`.
template isValidator(V)
{
    enum isValidator = __traits(hasMember, V, "isValidAlg") && __traits(hasMember, V, "isValid");
}

/// True when `S` has `signAlg` and `sign` members, as required by `encode`.
template isSigner(S)
{
    enum isSigner = __traits(hasMember, S, "signAlg") && __traits(hasMember, S, "sign");
}

/**
 * Poor man's JSON scan for the "alg" member of a JWT header.
 *
 * Params:
 *   hdr = JSON header
 *   alg = set to the parsed algorithm when 0 is returned
 * Returns: 0 ok, -1 missing, -2 error, -3 unknown or unsupported alg
 */
private int parseHeaderAlgorithm(H)(H hdr, out JWTAlgorithm alg)
    if (isToken!H)
{
    import std.ascii : isAlphaNum, isWhite;

    // Scan raw code units: range algorithms auto-decode narrow strings and return code point
    // counts, which don't match the code unit indices used for indexing and slicing below.
    auto bytes = cast(const(ubyte)[])hdr;
    static immutable ubyte[] key = ['"', 'a', 'l', 'g', '"'];

    // find the "alg" key, allowing whitespace between it and the ':'
    size_t idx;
    while (true)
    {
        immutable found = bytes[idx .. $].countUntil(key);
        if (found < 0) return -1; // alg value is REQUIRED
        idx += found + key.length;
        size_t colon = idx;
        while (colon < bytes.length && bytes[colon].isWhite) colon++;
        if (colon < bytes.length && bytes[colon] == ':')
        {
            idx = colon + 1;
            break;
        }
        // `"alg"` wasn't used as a key here (ie. it was a string value), keep searching
    }

    while (idx < bytes.length && bytes[idx].isWhite) idx++; // skip possible whitespaces
    if (idx == bytes.length || bytes[idx] != '"') return -2;
    immutable algStart = ++idx;
    // NOTE: expected only alphanum characters for supported JWT algorithms, but needs to be changed to support JWE
    while (idx < bytes.length && bytes[idx].isAlphaNum) idx++;
    if (idx == bytes.length || bytes[idx] != '"') return -2;
    auto algVal = cast(const(char)[])bytes[algStart .. idx];

    // get used algorithm (algStrings is indexed by JWTAlgorithm value)
    immutable salg = algStrings.countUntil(algVal);
    if (salg < 0) return -3;
    alg = cast(JWTAlgorithm)salg;
    return 0;
}

unittest
{
    JWTAlgorithm alg;
    assert(parseHeaderAlgorithm(`{"alg":"HS256"}`, alg) == 0 && alg == JWTAlgorithm.HS256);
    // non-ASCII content before "alg" must not shift the parsed position
    assert(parseHeaderAlgorithm("{\"kid\":\"\u017E\",\"alg\":\"HS384\"}", alg) == 0 && alg == JWTAlgorithm.HS384);
    // whitespace around ':' is valid JSON
    assert(parseHeaderAlgorithm(`{"alg" : "RS256"}`, alg) == 0 && alg == JWTAlgorithm.RS256);
    // "alg" used as a value is not the key
    assert(parseHeaderAlgorithm(`{"typ":"alg","alg":"ES512"}`, alg) == 0 && alg == JWTAlgorithm.ES512);
    assert(parseHeaderAlgorithm(`{"typ":"JWT"}`, alg) == -1);
    assert(parseHeaderAlgorithm(`{"alg":42}`, alg) == -2);
    assert(parseHeaderAlgorithm(`{"alg":"foo"}`, alg) == -3);
}