access_test.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. /*
  2. * EliasDB
  3. *
  4. * Copyright 2016 Matthias Ladkau. All rights reserved.
  5. *
  6. * This Source Code Form is subject to the terms of the Mozilla Public
  7. * License, v. 2.0. If a copy of the MPL was not distributed with this
  8. * file, You can obtain one at http://mozilla.org/MPL/2.0/.
  9. */
  10. package ac
  11. import (
  12. "bytes"
  13. "encoding/json"
  14. "flag"
  15. "fmt"
  16. "io/ioutil"
  17. "net/http"
  18. "os"
  19. "strings"
  20. "sync"
  21. "testing"
  22. "devt.de/krotik/common/datautil"
  23. "devt.de/krotik/common/errorutil"
  24. "devt.de/krotik/common/httputil"
  25. "devt.de/krotik/common/httputil/access"
  26. "devt.de/krotik/common/httputil/auth"
  27. "devt.de/krotik/common/stringutil"
  28. "devt.de/krotik/eliasdb/api"
  29. )
  // TESTPORT is the listen address of the HTTP test server used by the tests
// in this package; requests are sent to "http://localhost" + TESTPORT.
const TESTPORT = ":9090"
  // TestMain is the entry point for all tests in this package. It starts an
// HTTP test server, registers the public and the access-controlled REST
// endpoints plus a dummy page at /foo, sets up a user database with three
// test users and the default ACLs, runs the tests and finally tears
// everything down again.
func TestMain(m *testing.M) {
	var err error

	flag.Parse()

	// startServer panics if the server cannot be started, so hs is always
	// usable here (no nil check needed)
	hs, wg := startServer()

	// Disable access logging
	LogAccess = func(v ...interface{}) {}

	// Register public endpoints
	api.RegisterRestEndpoints(PublicAccessControlEndpointMap)

	// Initialise auth handler
	AuthHandler = auth.NewCookieAuthHandleFuncWrapper(http.HandleFunc)

	// Important statement! - all registered endpoints afterwards
	// are subject to access control
	api.HandleFunc = AuthHandler.HandleFunc

	// Register management endpoints
	api.RegisterRestEndpoints(AccessManagementEndpointMap)

	// Register dummy page
	api.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("foobar!"))
	})

	// Start from a clean user DB - a file left behind by an aborted earlier
	// run would still contain the test users added below. The error is
	// ignored on purpose: normally the file does not exist.
	os.Remove("test_user.db")

	// Initialise user DB
	UserDB, err = datautil.NewEnforcedUserDB("test_user.db", "")
	errorutil.AssertOk(err)

	// Put the UserDB in charge of verifying passwords
	AuthHandler.SetAuthFunc(UserDB.CheckUserPassword)

	// Initialise ACL's
	var conf map[string]interface{}
	errorutil.AssertOk(json.Unmarshal(stringutil.StripCStyleComments(DefaultAccessDB), &conf))

	at, err := access.NewMemoryACLTableFromConfig(conf)
	errorutil.AssertOk(err)

	InitACLs(at)

	// Connect the ACL object to the AuthHandler - this provides authorization for users
	AuthHandler.SetAccessFunc(ACL.CheckHTTPRequest)

	// Adding special handlers which redirect to the login page
	AuthHandler.CallbackSessionExpired = CallbackSessionExpired
	AuthHandler.CallbackUnauthorized = CallbackUnauthorized

	// Add users via the embedded user DB (presumably to bypass the password
	// rules of the enforced DB - "g" is a very weak password; TODO confirm).
	// Failures are fatal: tests relying on these users would otherwise fail
	// later in confusing ways.
	errorutil.AssertOk(UserDB.UserDB.AddUserEntry("elias", "elias", nil))
	errorutil.AssertOk(UserDB.UserDB.AddUserEntry("johndoe", "doe", nil))
	errorutil.AssertOk(UserDB.UserDB.AddUserEntry("guest", "g", nil))

	// Disable debounce time for unit tests
	DebounceTime = 0

	// Run the tests
	res := m.Run()

	// Teardown
	stopServer(hs, wg)

	// Stop ACL monitoring
	ACL.Close()

	// Remove files
	os.Remove("test_user.db")

	os.Exit(res)
}
  // TestSwaggerDefs checks that every public and every access management
// endpoint can contribute its definitions to a swagger document skeleton.
func TestSwaggerDefs(t *testing.T) {
	swagger := map[string]interface{}{
		"paths":       map[string]interface{}{},
		"definitions": map[string]interface{}{},
	}

	for _, newEndpoint := range PublicAccessControlEndpointMap {
		handler := newEndpoint()
		handler.SwaggerDefs(swagger)
	}

	for _, newEndpoint := range AccessManagementEndpointMap {
		handler := newEndpoint()
		handler.SwaggerDefs(swagger)
	}
}
  99. /*
  100. Start a HTTP test server.
  101. */
  102. func startServer() (*httputil.HTTPServer, *sync.WaitGroup) {
  103. hs := &httputil.HTTPServer{}
  104. var wg sync.WaitGroup
  105. wg.Add(1)
  106. go hs.RunHTTPServer(TESTPORT, &wg)
  107. wg.Wait()
  108. // Server is started
  109. if hs.LastError != nil {
  110. panic(hs.LastError)
  111. }
  112. return hs, &wg
  113. }
  /*
stopServer shuts down a server previously returned by startServer and blocks
until the wait group has been signalled again. It panics if the server is not
running.
*/
func stopServer(hs *httputil.HTTPServer, wg *sync.WaitGroup) {
	if !hs.Running {
		panic("Server was not running as expected")
	}

	// Re-arm the wait group before shutting down; wg.Wait returns once the
	// server side has signalled it (presumably when the server is down)
	wg.Add(1)
	hs.Shutdown()
	wg.Wait()
}
  127. /*
  128. Send a request to a HTTP test server
  129. */
  130. func sendTestRequest(contentType string, url string, method string, content []byte,
  131. reqMod func(*http.Request)) string {
  132. body, _ := sendTestRequestResponse(contentType, url, method, content, reqMod)
  133. return body
  134. }
  /*
sendTestRequestResponse sends a request to a HTTP test server and returns the
response body together with the response object.

contentType is sent as the Content-Type header, content (may be nil) is the
request body and reqMod (may be nil) can modify the request before it is sent.
The returned body is trimmed of surrounding spaces and newlines and, if it is
valid JSON, re-indented. The response body has already been read and closed
when this function returns. Request, transport and body read errors cause a
panic.
*/
func sendTestRequestResponse(contentType string, url string, method string,
	content []byte, reqMod func(*http.Request)) (string, *http.Response) {

	var req *http.Request
	var err error

	// Only attach a body reader when there is content - a nil body produces
	// a request without a body
	if content != nil {
		req, err = http.NewRequest(method, url, bytes.NewBuffer(content))
	} else {
		req, err = http.NewRequest(method, url, nil)
	}
	if err != nil {
		panic(err)
	}

	req.Header.Set("Content-Type", contentType)

	if reqMod != nil {
		reqMod(req)
	}

	client := &http.Client{}

	resp, err := client.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// A failed read (e.g. a truncated response) must not be mistaken for a
	// valid but shorter body
	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		panic(err)
	}

	bodyStr := strings.Trim(string(body), " \n")

	// Try json decoding first
	out := bytes.Buffer{}
	if err = json.Indent(&out, []byte(bodyStr), "", " "); err == nil {
		return out.String(), resp
	}

	// Just return the body
	return bodyStr, resp
}
  /*
doAuth logs in with the given credentials via the login endpoint of the test
server and returns the authentication cookie from the response. It fails (via
errorutil assertions) if the response carries no cookie or if the first
cookie is not named "~aid".
*/
func doAuth(user, pass string) *http.Cookie {
	queryURL := "http://localhost" + TESTPORT

	// Build the request body with encoding/json rather than by string
	// concatenation so that credentials containing quotes, backslashes or
	// control characters still produce a valid JSON document
	credentials, err := json.Marshal(map[string]string{
		"user": user,
		"pass": pass,
	})
	errorutil.AssertOk(err)

	// Send authentication request with the given credentials
	res, resp := sendTestRequestResponse("application/json", queryURL+EndpointLogin, "POST",
		credentials, nil)

	// Parse the Set-Cookie headers only once
	cookies := resp.Cookies()
	errorutil.AssertTrue(len(cookies) > 0, res)

	// Right after authentication we only have the authentication cookie - after
	// the first visit to a non-public page we will also have a session cookie
	authCookie := cookies[0]
	errorutil.AssertTrue(authCookie.Name == "~aid",
		fmt.Sprint("Unexpected name for cookie:", authCookie))

	return authCookie
}