reid před 2 roky
rodič
revize
23481c669f
3 změnil soubory, kde provedl 67 přidání a 55 odebrání
  1. 10 10
      auth/auth.go
  2. 6 6
      broadcast/broadcast.go
  3. 51 39
      ws/ws.go

+ 10 - 10
auth/auth.go

@@ -43,8 +43,8 @@ var (
 
 // check if websocket-token pair is auth'd
 func WsIsAuthenticated(conn *websocket.Conn, token string) bool {
-	AuthenticatedClients.RLock()         // Acquire read lock
-	defer AuthenticatedClients.RUnlock() // Release read lock
+	AuthenticatedClients.RLock()
+	defer AuthenticatedClients.RUnlock()
 	if AuthenticatedClients.Conns[token] == conn {
 		return true
 	} else {
@@ -54,8 +54,8 @@ func WsIsAuthenticated(conn *websocket.Conn, token string) bool {
 
 // quick check if websocket is authed at all for unauth broadcast (not for actual auth)
 func WsAuthCheck(conn *websocket.Conn) bool {
-	AuthenticatedClients.RLock()         // Acquire read lock
-	defer AuthenticatedClients.RUnlock() // Release read lock
+	AuthenticatedClients.RLock()
+	defer AuthenticatedClients.RUnlock()
 	for _, con := range AuthenticatedClients.Conns {
 		if con == conn {
 			return true
@@ -98,14 +98,14 @@ func AddToAuthMap(conn *websocket.Conn, token map[string]string, authed bool) er
 }
 
 // check the validity of the token
-func CheckToken(token map[string]string, conn *websocket.Conn, r *http.Request, setup bool) (bool, string, error) {
+func CheckToken(token map[string]string, conn *websocket.Conn, r *http.Request, setup bool) bool {
 	// great you have token. we see if valid.
 	conf := config.Conf()
 	key := conf.KeyFile
 	res, err := KeyfileDecrypt(token["token"], key)
 	if err != nil {
 		config.Logger.Warn("Invalid token provided")
-		return false, "", err
+		return false
 	} else {
 		// so you decrypt. now we see the useragent and ip.
 		var ip string
@@ -120,17 +120,17 @@ func CheckToken(token map[string]string, conn *websocket.Conn, r *http.Request,
 		// you in auth map?
 		if WsIsAuthenticated(conn, hash) {
 			if ip == res["ip"] && userAgent == res["user_agent"] {
-				return true, res["id"], nil
+				return true
 			} else {
 				config.Logger.Warn("Token doesn't match session!")
-				return false, res["id"], err
+				return false
 			}
 		} else {
 			config.Logger.Warn("Token isn't an authenticated session")
-			return false, res["id"], err
+			return false
 		}
 	}
-	return false, "", nil
+	return false
 }
 
 // create a new session token

+ 6 - 6
broadcast/broadcast.go

@@ -351,12 +351,12 @@ func BroadcastToClients() error {
 		}
 	}
 	// for debug, remove me
-	for client := range clients {
-		if err := client.WriteMessage(websocket.TextMessage, authJson); err != nil {
-			config.Logger.Error(fmt.Sprintf("Error writing response: %v", err))
-			return err
-		}
-	}
+	// for client := range clients {
+	// 	if err := client.WriteMessage(websocket.TextMessage, authJson); err != nil {
+	// 		config.Logger.Error(fmt.Sprintf("Error writing response: %v", err))
+	// 		return err
+	// 	}
+	// }
 	return nil
 }
 

+ 51 - 39
ws/ws.go

@@ -44,6 +44,7 @@ var (
 
 // switch on ws event cases
 func WsHandler(w http.ResponseWriter, r *http.Request) {
+	conf := config.Conf()
 	conn, err := upgrader.Upgrade(w, r, nil)
 	if err != nil {
 		config.Logger.Error(fmt.Sprintf("Couldn't upgrade websocket connection: %v", err))
@@ -105,38 +106,50 @@ func WsHandler(w http.ResponseWriter, r *http.Request) {
 			continue
 		}
 		payload.Payload = structs.WsLoginPayload{}
-		switch payload.Type {
-		case "login":
-			if err = loginHandler(conn, msg, payload); err != nil {
-				config.Logger.Error(fmt.Sprintf("%v", err))
-			}
-		case "setup":
-			config.Logger.Info("Setup")
-			// setup.Setup(payload)
-		case "new_ship":
-			config.Logger.Info("New ship")
-		case "pier_upload":
-			config.Logger.Info("Pier upload")
-		case "password":
-			config.Logger.Info("Password")
-		case "system":
-			config.Logger.Info("System")
-		case "startram":
-			config.Logger.Info("StarTram")
-		case "urbit":
-			config.Logger.Info("Urbit")
-		case "support":
-			if err = supportHandler(msg, payload, r, conn); err != nil {
-				config.Logger.Error(fmt.Sprintf("%v", err))
+		token := map[string]string{
+			"id":    payload.Token.ID,
+			"token": payload.Token.Token,
+		}
+		if auth.CheckToken(token, conn, r, conf.FirstBoot) {
+			switch payload.Type {
+			case "new_ship":
+				config.Logger.Info("New ship")
+			case "pier_upload":
+				config.Logger.Info("Pier upload")
+			case "password":
+				config.Logger.Info("Password")
+			case "system":
+				config.Logger.Info("System")
+			case "startram":
+				config.Logger.Info("StarTram")
+			case "urbit":
+				config.Logger.Info("Urbit")
+			case "support":
+				if err = supportHandler(msg, payload, r, conn); err != nil {
+					config.Logger.Error(fmt.Sprintf("%v", err))
+				}
+			case "broadcast":
+				if err := broadcast.BroadcastToClients(); err != nil {
+					errmsg := fmt.Sprintf("Unable to broadcast to peer(s): %v", err)
+					config.Logger.Error(errmsg)
+				}
+			default:
+				errmsg := fmt.Sprintf("Unknown request type: %s", payload.Type)
+				config.Logger.Warn(errmsg)
 			}
-		case "broadcast":
-			if err := broadcast.BroadcastToClients(); err != nil {
-				errmsg := fmt.Sprintf("Unable to broadcast to peer(s): %v", err)
-				config.Logger.Error(errmsg)
+		} else {
+			switch payload.Type {
+			case "login":
+				if err = loginHandler(conn, msg, payload); err != nil {
+					config.Logger.Error(fmt.Sprintf("%v", err))
+				}
+			case "setup":
+				config.Logger.Info("Setup")
+				// setup.Setup(payload)
+			default:
+				errmsg := fmt.Sprintf("Unknown request type: %s", payload.Type)
+				config.Logger.Warn(errmsg)
 			}
-		default:
-			errmsg := fmt.Sprintf("Unknown request type: %s", payload.Type)
-			config.Logger.Warn(errmsg)
 		}
 	}
 	// default to unauth
@@ -172,18 +185,17 @@ func loginHandler(conn *websocket.Conn, msg []byte, payload structs.WsPayload) e
 	return nil
 }
 
-
 // broadcast the unauth payload
 func unauthHandler(conn *websocket.Conn, r *http.Request) {
 	blob := structs.UnauthBroadcast{
-        Type:      "structure",
-        AuthLevel: "unauthorized",
-        Login: struct {
-            Remainder int `json:"remainder"`
-        }{
-            Remainder: 0,
-        },
-    }
+		Type:      "structure",
+		AuthLevel: "unauthorized",
+		Login: struct {
+			Remainder int `json:"remainder"`
+		}{
+			Remainder: 0,
+		},
+	}
 	resp, err := json.Marshal(blob)
 	if err != nil {
 		config.Logger.Error(fmt.Sprintf("Error unmarshalling message: %v", err))