1package jwx
2
3import (
4	"bytes"
5	"encoding/json"
6)
7
// FormatKind is the classification returned by GuessFormat: it names which
// of the JWx formats a payload appears to be encoded in.
type FormatKind int

// The values GuessFormat can report. UnknownFormat is the zero value and is
// returned when none of the heuristics match.
const (
	UnknownFormat FormatKind = 0 // no heuristic matched
	JWE           FormatKind = 1 // JSON Web Encryption
	JWS           FormatKind = 2 // JSON Web Signature
	JWK           FormatKind = 3 // JSON Web Key
	JWKS          FormatKind = 4 // JSON Web Key Set
	JWT           FormatKind = 5 // JSON Web Token
)
18
19type formatHint struct {
20	Payload    json.RawMessage `json:"payload"`    // Only in JWS
21	Signatures json.RawMessage `json:"signatures"` // Only in JWS
22	Ciphertext json.RawMessage `json:"ciphertext"` // Only in JWE
23	KeyType    json.RawMessage `json:"kty"`        // Only in JWK
24	Keys       json.RawMessage `json:"keys"`       // Only in JWKS
25	Audience   json.RawMessage `json:"aud"`        // Only in JWT
26}
27
28// GuessFormat is used to guess the format the given payload is in
29// using heuristics. See the type FormatKind for a full list of
30// possible types.
31//
32// This may be useful in determining your next action when you may
33// encounter a payload that could either be a JWE, JWS, or a plain JWT.
34//
35// Because JWTs are almost always JWS signed, you may be thrown off
36// if you pass what you think is a JWT payload to this function.
37// If the function is in the "Compact" format, it means it's a JWS
38// signed message, and its payload is the JWT. Therefore this function
39// will reuturn JWS, not JWT.
40//
41// This function requires an extra parsing of the payload, and therefore
42// may be inefficient if you call it every time before parsing.
43func GuessFormat(payload []byte) FormatKind {
44	// The check against kty, keys, and aud are something this library
45	// made up. for the distinctions between JWE and JWS, we used
46	// https://datatracker.ietf.org/doc/html/rfc7516#section-9.
47	//
48	// The above RFC described several ways to distinguish between
49	// a JWE and JWS JSON, but we're only using one of them
50
51	payload = bytes.TrimSpace(payload)
52	if len(payload) <= 0 {
53		return UnknownFormat
54	}
55
56	if payload[0] != '{' {
57		// Compact format. It's probably a JWS or JWE
58		sep := []byte{'.'} // I want to const this :/
59
60		// Note: this counts the number of occurrences of the
61		// separator, but the RFC talks about the number of segments.
62		// number of '.' == segments - 1, so that's why we have 2 and 4 here
63		switch count := bytes.Count(payload, sep); count {
64		case 2:
65			return JWS
66		case 4:
67			return JWE
68		default:
69			return UnknownFormat
70		}
71	}
72
73	// If we got here, we probably have JSON.
74	var h formatHint
75	if err := json.Unmarshal(payload, &h); err != nil {
76		return UnknownFormat
77	}
78
79	if h.Audience != nil {
80		return JWT
81	}
82	if h.KeyType != nil {
83		return JWK
84	}
85	if h.Keys != nil {
86		return JWKS
87	}
88	if h.Ciphertext != nil {
89		return JWE
90	}
91	if h.Signatures != nil && h.Payload != nil {
92		return JWS
93	}
94	return UnknownFormat
95}
96