cache.go 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. package validator
  2. import (
  3. "fmt"
  4. "reflect"
  5. "strings"
  6. "sync"
  7. "sync/atomic"
  8. )
  9. type tagType uint8
  10. const (
  11. typeDefault tagType = iota
  12. typeOmitEmpty
  13. typeIsDefault
  14. typeNoStructLevel
  15. typeStructOnly
  16. typeDive
  17. typeOr
  18. typeKeys
  19. typeEndKeys
  20. typeOmitNil
  21. typeOmitZero
  22. )
  23. const (
  24. invalidValidation = "Invalid validation tag on field '%s'"
  25. undefinedValidation = "Undefined validation function '%s' on field '%s'"
  26. keysTagNotDefined = "'" + endKeysTag + "' tag encountered without a corresponding '" + keysTag + "' tag"
  27. )
  28. type structCache struct {
  29. lock sync.Mutex
  30. m atomic.Value // map[reflect.Type]*cStruct
  31. }
  32. func (sc *structCache) Get(key reflect.Type) (c *cStruct, found bool) {
  33. c, found = sc.m.Load().(map[reflect.Type]*cStruct)[key]
  34. return
  35. }
  36. func (sc *structCache) Set(key reflect.Type, value *cStruct) {
  37. m := sc.m.Load().(map[reflect.Type]*cStruct)
  38. nm := make(map[reflect.Type]*cStruct, len(m)+1)
  39. for k, v := range m {
  40. nm[k] = v
  41. }
  42. nm[key] = value
  43. sc.m.Store(nm)
  44. }
  45. type tagCache struct {
  46. lock sync.Mutex
  47. m atomic.Value // map[string]*cTag
  48. }
  49. func (tc *tagCache) Get(key string) (c *cTag, found bool) {
  50. c, found = tc.m.Load().(map[string]*cTag)[key]
  51. return
  52. }
  53. func (tc *tagCache) Set(key string, value *cTag) {
  54. m := tc.m.Load().(map[string]*cTag)
  55. nm := make(map[string]*cTag, len(m)+1)
  56. for k, v := range m {
  57. nm[k] = v
  58. }
  59. nm[key] = value
  60. tc.m.Store(nm)
  61. }
  62. type cStruct struct {
  63. name string
  64. fields []*cField
  65. fn StructLevelFuncCtx
  66. }
  67. type cField struct {
  68. idx int
  69. name string
  70. altName string
  71. namesEqual bool
  72. cTags *cTag
  73. }
  74. type cTag struct {
  75. tag string
  76. aliasTag string
  77. actualAliasTag string
  78. param string
  79. keys *cTag // only populated when using tag's 'keys' and 'endkeys' for map key validation
  80. next *cTag
  81. fn FuncCtx
  82. typeof tagType
  83. hasTag bool
  84. hasAlias bool
  85. hasParam bool // true if parameter used eg. eq= where the equal sign has been set
  86. isBlockEnd bool // indicates the current tag represents the last validation in the block
  87. runValidationWhenNil bool
  88. }
  89. func (v *Validate) extractStructCache(current reflect.Value, sName string) *cStruct {
  90. v.structCache.lock.Lock()
  91. defer v.structCache.lock.Unlock() // leave as defer! because if inner panics, it will never get unlocked otherwise!
  92. typ := current.Type()
  93. // could have been multiple trying to access, but once first is done this ensures struct
  94. // isn't parsed again.
  95. cs, ok := v.structCache.Get(typ)
  96. if ok {
  97. return cs
  98. }
  99. cs = &cStruct{name: sName, fields: make([]*cField, 0), fn: v.structLevelFuncs[typ]}
  100. numFields := current.NumField()
  101. rules := v.rules[typ]
  102. var ctag *cTag
  103. var fld reflect.StructField
  104. var tag string
  105. var customName string
  106. for i := 0; i < numFields; i++ {
  107. fld = typ.Field(i)
  108. if !v.privateFieldValidation && !fld.Anonymous && len(fld.PkgPath) > 0 {
  109. continue
  110. }
  111. if rtag, ok := rules[fld.Name]; ok {
  112. tag = rtag
  113. } else {
  114. tag = fld.Tag.Get(v.tagName)
  115. }
  116. if tag == skipValidationTag {
  117. continue
  118. }
  119. customName = fld.Name
  120. if v.hasTagNameFunc {
  121. name := v.tagNameFunc(fld)
  122. if len(name) > 0 {
  123. customName = name
  124. }
  125. }
  126. // NOTE: cannot use shared tag cache, because tags may be equal, but things like alias may be different
  127. // and so only struct level caching can be used instead of combined with Field tag caching
  128. if len(tag) > 0 {
  129. ctag, _ = v.parseFieldTagsRecursive(tag, fld.Name, "", false)
  130. } else {
  131. // even if field doesn't have validations need cTag for traversing to potential inner/nested
  132. // elements of the field.
  133. ctag = new(cTag)
  134. }
  135. cs.fields = append(cs.fields, &cField{
  136. idx: i,
  137. name: fld.Name,
  138. altName: customName,
  139. cTags: ctag,
  140. namesEqual: fld.Name == customName,
  141. })
  142. }
  143. v.structCache.Set(typ, cs)
  144. return cs
  145. }
  146. func (v *Validate) parseFieldTagsRecursive(tag string, fieldName string, alias string, hasAlias bool) (firstCtag *cTag, current *cTag) {
  147. var t string
  148. noAlias := len(alias) == 0
  149. tags := strings.Split(tag, tagSeparator)
  150. for i := 0; i < len(tags); i++ {
  151. t = tags[i]
  152. if noAlias {
  153. alias = t
  154. }
  155. // check map for alias and process new tags, otherwise process as usual
  156. if tagsVal, found := v.aliases[t]; found {
  157. if i == 0 {
  158. firstCtag, current = v.parseFieldTagsRecursive(tagsVal, fieldName, t, true)
  159. } else {
  160. next, curr := v.parseFieldTagsRecursive(tagsVal, fieldName, t, true)
  161. current.next, current = next, curr
  162. }
  163. continue
  164. }
  165. var prevTag tagType
  166. if i == 0 {
  167. current = &cTag{aliasTag: alias, hasAlias: hasAlias, hasTag: true, typeof: typeDefault}
  168. firstCtag = current
  169. } else {
  170. prevTag = current.typeof
  171. current.next = &cTag{aliasTag: alias, hasAlias: hasAlias, hasTag: true}
  172. current = current.next
  173. }
  174. switch t {
  175. case diveTag:
  176. current.typeof = typeDive
  177. continue
  178. case keysTag:
  179. current.typeof = typeKeys
  180. if i == 0 || prevTag != typeDive {
  181. panic(fmt.Sprintf("'%s' tag must be immediately preceded by the '%s' tag", keysTag, diveTag))
  182. }
  183. current.typeof = typeKeys
  184. // need to pass along only keys tag
  185. // need to increment i to skip over the keys tags
  186. b := make([]byte, 0, 64)
  187. i++
  188. for ; i < len(tags); i++ {
  189. b = append(b, tags[i]...)
  190. b = append(b, ',')
  191. if tags[i] == endKeysTag {
  192. break
  193. }
  194. }
  195. current.keys, _ = v.parseFieldTagsRecursive(string(b[:len(b)-1]), fieldName, "", false)
  196. continue
  197. case endKeysTag:
  198. current.typeof = typeEndKeys
  199. // if there are more in tags then there was no keysTag defined
  200. // and an error should be thrown
  201. if i != len(tags)-1 {
  202. panic(keysTagNotDefined)
  203. }
  204. return
  205. case omitzero:
  206. current.typeof = typeOmitZero
  207. continue
  208. case omitempty:
  209. current.typeof = typeOmitEmpty
  210. continue
  211. case omitnil:
  212. current.typeof = typeOmitNil
  213. continue
  214. case structOnlyTag:
  215. current.typeof = typeStructOnly
  216. continue
  217. case noStructLevelTag:
  218. current.typeof = typeNoStructLevel
  219. continue
  220. default:
  221. if t == isdefault {
  222. current.typeof = typeIsDefault
  223. }
  224. // if a pipe character is needed within the param you must use the utf8Pipe representation "0x7C"
  225. orVals := strings.Split(t, orSeparator)
  226. for j := 0; j < len(orVals); j++ {
  227. vals := strings.SplitN(orVals[j], tagKeySeparator, 2)
  228. if noAlias {
  229. alias = vals[0]
  230. current.aliasTag = alias
  231. } else {
  232. current.actualAliasTag = t
  233. }
  234. if j > 0 {
  235. current.next = &cTag{aliasTag: alias, actualAliasTag: current.actualAliasTag, hasAlias: hasAlias, hasTag: true}
  236. current = current.next
  237. }
  238. current.hasParam = len(vals) > 1
  239. current.tag = vals[0]
  240. if len(current.tag) == 0 {
  241. panic(strings.TrimSpace(fmt.Sprintf(invalidValidation, fieldName)))
  242. }
  243. if wrapper, ok := v.validations[current.tag]; ok {
  244. current.fn = wrapper.fn
  245. current.runValidationWhenNil = wrapper.runValidationOnNil
  246. } else {
  247. panic(strings.TrimSpace(fmt.Sprintf(undefinedValidation, current.tag, fieldName)))
  248. }
  249. if len(orVals) > 1 {
  250. current.typeof = typeOr
  251. }
  252. if len(vals) > 1 {
  253. current.param = strings.ReplaceAll(strings.ReplaceAll(vals[1], utf8HexComma, ","), utf8Pipe, "|")
  254. }
  255. }
  256. current.isBlockEnd = true
  257. }
  258. }
  259. return
  260. }
  261. func (v *Validate) fetchCacheTag(tag string) *cTag {
  262. // find cached tag
  263. ctag, found := v.tagCache.Get(tag)
  264. if !found {
  265. v.tagCache.lock.Lock()
  266. defer v.tagCache.lock.Unlock()
  267. // could have been multiple trying to access, but once first is done this ensures tag
  268. // isn't parsed again.
  269. ctag, found = v.tagCache.Get(tag)
  270. if !found {
  271. ctag, _ = v.parseFieldTagsRecursive(tag, "", "", false)
  272. v.tagCache.Set(tag, ctag)
  273. }
  274. }
  275. return ctag
  276. }