reid 2 gadi atpakaļ
vecāks
revīzija
0f1987918b
4 mainītis faili ar 100 papildinājumiem un 23 dzēšanām
  1. 36 5
      auth/auth.go
  2. 13 8
      broadcast/broadcast.go
  3. 0 6
      structs/structs.go
  4. 51 4
      ws/ws.go

+ 36 - 5
auth/auth.go

@@ -9,18 +9,29 @@ import (
 	"goseg/config"
 	"fmt"
 	"time"
+	"sync"
 
+	"github.com/gorilla/websocket"
 	"golang.org/x/crypto/nacl/secretbox"
 )
 
+var (
+	AuthenticatedClients = struct{
+		Conns map[*websocket.Conn]bool
+		sync.Mutex
+	}{
+		Conns: make(map[*websocket.Conn]bool),
+	}
+)
+
 // CheckToken checks the validity of the token.
-func CheckToken(token string, websocket interface{}, setup bool) (bool, string, error) {
+func CheckToken(token string, conn *websocket.Conn, setup bool) (bool, string, error) {
 	if token == "" {
 		authStatus := false
 		if setup {
 			authStatus = true
 		}
-		newToken, err := CreateToken(websocket, setup)
+		newToken, err := CreateToken(conn, setup)
 		if err != nil {
 			return false, "", err
 		}
@@ -34,7 +45,7 @@ func CheckToken(token string, websocket interface{}, setup bool) (bool, string,
 }
 
 // CreateToken creates a new session token.
-func CreateToken(websocket interface{}, setup bool) (map[string]string, error) {
+func CreateToken(conn *websocket.Conn, setup bool) (map[string]string, error) {
 	// placeholder logic for IP and UserAgent
 	ip := "localhost"
 	userAgent := "lick"
@@ -148,6 +159,26 @@ func NewSecretString(length int) string {
 	return base64.URLEncoding.EncodeToString(randBytes)
 }
 
-func auth() {
-	fmt.Println("Compiled")
+func Hasher(password string) string {
+    conf := config.Conf()
+    salt := conf.Salt
+    toHash := password + salt
+    res := sha256.Sum256([]byte(toHash))
+    return hex.EncodeToString(res[:])
+}
+
+func AuthenticateLogin(password string) bool {
+	conf := config.Conf()
+	if Hasher(password) == conf.PwHash {
+		return true
+	} else {
+		return false
+	}
+}
+
+func WsIsAuthenticated(conn *websocket.Conn) bool {
+    AuthenticatedClients.Lock()
+    _, exists := AuthenticatedClients.Conns[conn]
+    AuthenticatedClients.Unlock()
+    return exists
 }

+ 13 - 8
broadcast/broadcast.go

@@ -3,6 +3,7 @@ package broadcast
 import (
 	"encoding/json"
 	"fmt"
+	"goseg/auth"
 	"goseg/config"
 	"goseg/docker"
 	"goseg/startram"
@@ -22,6 +23,7 @@ var (
 	logger         = slog.New(slog.NewJSONHandler(os.Stdout, nil))
 	clients        = make(map[*websocket.Conn]bool)
 	broadcastState structs.AuthBroadcast
+	unauthState    structs.UnauthBroadcast
 	mu             sync.RWMutex // synchronize access to broadcastState
 )
 
@@ -323,17 +325,20 @@ func GetStateJson() ([]byte, error) {
 	return broadcastJson, nil
 }
 
-// broadcast the global state to all clients
+// broadcast the global state to auth'd clients
 func BroadcastToClients() error {
-	broadcastJson, err := json.Marshal(broadcastState)
+	authJson, err := json.Marshal(broadcastState)
 	if err != nil {
-		logger.Error("Error marshalling response:", err)
-		return err
+		errmsg := fmt.Errorf("Error marshalling auth broadcast:", err)
+		return errmsg
 	}
-	for client := range clients {
-		if err := client.WriteMessage(websocket.TextMessage, broadcastJson); err != nil {
-			logger.Error("Error writing response:", err)
-			return err
+	for client := range clients { 
+		_, authenticated := auth.AuthenticatedClients.Conns[client]
+		if authenticated {
+			if err := client.WriteMessage(websocket.TextMessage, authJson); err != nil {
+				logger.Error("Error writing response:", err)
+				return err
+			}
 		}
 	}
 	return nil

+ 0 - 6
structs/structs.go

@@ -4,12 +4,6 @@ import (
 	"github.com/docker/docker/api/types/container"
 )
 
-// incoming websocket payloads
-type WsPayload struct {
-	Type   string `json:"type"`
-	Action string `json:"action"`
-}
-
 // eventbus event payloads
 type Event struct {
 	Type string

+ 51 - 4
ws/ws.go

@@ -4,6 +4,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"github.com/gorilla/websocket"
+	"goseg/auth"
 	"goseg/broadcast"
 	"goseg/structs"
 	"log/slog"
@@ -20,6 +21,26 @@ var (
 	}
 )
 
+// func handleConnection(c *websocket.Conn) {
+//     // Read the first message from the client which should be the token
+//     messageType, p, err := c.ReadMessage()
+//     if err != nil {
+//         logger.Error(fmt.Errorf("%v",err))
+//         return
+//     }
+//     token := string(p)
+//     // Verify the token
+//     isValid, _, err := CheckToken(token, c, false)  // 'false' assumes it's not a setup
+//     if !isValid || err != nil {
+//         logger.Info("Invalid token provided by client.")
+//         c.Close()
+//         return
+//     }
+
+//     // rest of logic
+// }
+
+
 // switch on ws event cases
 func WsHandler(w http.ResponseWriter, r *http.Request) {
 	conn, err := upgrader.Upgrade(w, r, nil)
@@ -53,12 +74,38 @@ func WsHandler(w http.ResponseWriter, r *http.Request) {
 		if err != nil {
 			return
 		}
-		var payload structs.WsPayload
-		if err := json.Unmarshal(msg, &payload); err != nil {
+		var prelim structs.WsType
+		if err := json.Unmarshal(msg, &prelim); err != nil {
 			fmt.Println("Error unmarshalling message:", err)
 			continue
 		}
-		switch payload.Type {
+		switch prelim.Payload.Type {
+		case "login":
+			var payload structs.WsPayload
+			payload.Payload = structs.LoginPayload{}
+			if err := json.Unmarshal(msg, &payload); err != nil {
+				fmt.Println("Error unmarshalling message:", err)
+				continue
+			}
+			loginPayload, ok := payload.Payload.(structs.LoginPayload)
+			if !ok {
+				fmt.Println("Error casting to LoginPayload")
+				continue
+			}
+			isAuthenticated := auth.AuthenticateLogin(loginPayload.Password)
+			if isAuthenticated {
+				_, err := auth.CreateToken(conn, false)
+				if err != nil {
+					fmt.Println("Error creating token:", err)
+					continue
+				}
+				auth.AuthenticatedClients.Lock()
+				auth.AuthenticatedClients.Conns[conn] = true
+				auth.AuthenticatedClients.Unlock()
+                } else {
+                    // Handle invalid login
+                    fmt.Println("Invalid login attempt")
+                }
 		case "setup":
 			logger.Info("Setup")
 			// setup.Setup(payload)
@@ -82,7 +129,7 @@ func WsHandler(w http.ResponseWriter, r *http.Request) {
 				logger.Error(errmsg)
 			}
 		default:
-			errmsg := fmt.Sprintf("Unknown request type:", payload.Type)
+			errmsg := fmt.Sprintf("Unknown request type: %s", prelim.Payload.Type)
 			logger.Warn(errmsg)
 		}
 	}