Просмотр исходного кода

additional work on auth and websockets

reid 2 лет назад
Родитель
Сommit
c6e51c85ab
5 измененных файлов с 211 добавлено и 131 удалено
  1. 106 86
      auth/auth.go
  2. 33 14
      broadcast/broadcast.go
  3. 15 0
      config/config.go
  4. 6 6
      structs/ws.go
  5. 51 25
      ws/ws.go

+ 106 - 86
auth/auth.go

@@ -6,56 +6,77 @@ import (
 	"encoding/base64"
 	"encoding/hex"
 	"encoding/json"
-	"goseg/config"
 	"fmt"
-	"time"
+	"goseg/config"
+	"goseg/structs"
+	"net"
+	"net/http"
+	"strings"
 	"sync"
+	"time"
 
 	"github.com/gorilla/websocket"
 	"golang.org/x/crypto/nacl/secretbox"
 )
 
 var (
-	AuthenticatedClients = struct{
-		Conns map[*websocket.Conn]bool
+	// maps a websocket conn to a tokenid
+	// tokenid's can be referenced from the global conf
+	AuthenticatedClients = struct {
+		Conns map[*websocket.Conn]string
+		sync.Mutex
+	}{
+		Conns: make(map[*websocket.Conn]string),
+	}
+	UnauthClients = struct {
+		Conns map[*websocket.Conn]string
 		sync.Mutex
 	}{
-		Conns: make(map[*websocket.Conn]bool),
+		Conns: make(map[*websocket.Conn]string),
 	}
 )
 
-// CheckToken checks the validity of the token.
-func CheckToken(token string, conn *websocket.Conn, setup bool) (bool, string, error) {
-	if token == "" {
-		authStatus := false
-		if setup {
-			authStatus = true
-		}
-		newToken, err := CreateToken(conn, setup)
-		if err != nil {
-			return false, "", err
-		}
-		return authStatus, newToken["token"], nil
-	}
+// check if websocket session is in the auth map
+func WsIsAuthenticated(conn *websocket.Conn) bool {
+	AuthenticatedClients.Lock()
+	_, exists := AuthenticatedClients.Conns[conn]
+	AuthenticatedClients.Unlock()
+	return exists
+}
 
-	// Token verification logic here...
-	// ...
+// // CheckToken checks the validity of the token.
+// func CheckToken(token string, conn *websocket.Conn, setup bool) (bool, string, error) {
+// 	if token == "" {
+// 		authStatus := false
+// 		if setup {
+// 			authStatus = true
+// 		}
+// 		newToken, err := CreateToken(conn, setup)
+// 		if err != nil {
+// 			return false, "", err
+// 		}
+// 		return authStatus, newToken["token"], nil
+// 	}
 
-	return false, "", nil
-}
+// 	return false, "", nil
+// }
 
 // CreateToken creates a new session token.
-func CreateToken(conn *websocket.Conn, setup bool) (map[string]string, error) {
-	// placeholder logic for IP and UserAgent
-	ip := "localhost"
-	userAgent := "lick"
-	// if more properties are needed from websocket, you can extract them here
+func CreateToken(conn *websocket.Conn, r *http.Request, setup bool) (map[string]string, error) {
+	// extract conn info
+	var ip string
+	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
+		ip = strings.Split(forwarded, ",")[0]
+	} else {
+		ip, _, _ = net.SplitHostPort(r.RemoteAddr)
+	}
+	userAgent := r.Header.Get("User-Agent")
 	conf := config.Conf()
-	key := conf.KeyFile
+	now := time.Now().Format("2006-01-02_15:04:05")
 	// generate random strings for id, secret, and padding
-	id := NewSecretString(32)
-	secret := NewSecretString(128)
-	padding := NewSecretString(32)
+	id := config.RandString(32)
+	secret := config.RandString(128)
+	padding := config.RandString(32)
 	contents := map[string]string{
 		"id":         id,
 		"ip":         ip,
@@ -63,42 +84,56 @@ func CreateToken(conn *websocket.Conn, setup bool) (map[string]string, error) {
 		"secret":     secret,
 		"padding":    padding,
 		"authorized": fmt.Sprintf("%v", setup),
-		"created":    time.Now().Format("2006-01-02_15:04:05"),
+		"created":    now,
 	}
 	// encrypt the contents
+	key := conf.KeyFile
 	encryptedText, err := KeyfileEncrypt(contents, key)
 	if err != nil {
 		return nil, fmt.Errorf("failed to encrypt token: %v", err)
 	}
-	// and hash
 	hashed := sha256.Sum256([]byte(encryptedText))
+	hash := hex.EncodeToString(hashed[:])
 	// Update sessions in the system's configuration
-	now := time.Now().Format("2006-01-02_15:04:05")
-	if setup {
-		updates := map[string]interface{}{
+	AddSession(id, hash, now, setup)
+	return map[string]string{
+		"id":    id,
+		"token": encryptedText,
+	}, nil
+}
+
+// take session details and add to SysConfig
+func AddSession(tokenID string, hash string, created string, authorized bool) error {
+	session := structs.SessionInfo{
+		Hash:    hash,
+		Created: created,
+	}
+	if authorized {
+		update := map[string]interface{}{
 			"Sessions": map[string]interface{}{
 				"Authorized": map[string]string{
-					"Hash":    hex.EncodeToString(hashed[:]),
-					"Created": now,
+					"Hash":    session.Hash,
+					"Created": session.Created,
 				},
 			},
 		}
-		config.UpdateConf(updates)
+		if err := config.UpdateConf(update); err != nil {
+			return fmt.Errorf("Error adding session: %v", err)
+		}
 	} else {
-		updates := map[string]interface{}{
+		update := map[string]interface{}{
 			"Sessions": map[string]interface{}{
 				"Unauthorized": map[string]string{
-					"Hash":    hex.EncodeToString(hashed[:]),
-					"Created": now,
+					"Hash":    session.Hash,
+					"Created": session.Created,
 				},
 			},
 		}
-		config.UpdateConf(updates)
+		if err := config.UpdateConf(update); err != nil {
+			return fmt.Errorf("Error adding session: %v", err)
+		}
 	}
-	return map[string]string{
-		"id":    id,
-		"token": encryptedText,
-	}, nil
+	return nil
 }
 
 // encrypt the contents using stored keyfile val
@@ -127,46 +162,38 @@ func KeyfileEncrypt(contents map[string]string, key string) (string, error) {
 // decrypt routine
 func KeyfileDecrypt(encryptedText string, key string) (map[string]string, error) {
 	// get bytes
-    keyBytes := []byte(key)
-    var keyArray [32]byte
-    copy(keyArray[:], keyBytes)
-    encryptedBytes, err := base64.URLEncoding.DecodeString(encryptedText)
-    if err != nil {
-        return nil, err
-    }
+	keyBytes := []byte(key)
+	var keyArray [32]byte
+	copy(keyArray[:], keyBytes)
+	encryptedBytes, err := base64.URLEncoding.DecodeString(encryptedText)
+	if err != nil {
+		return nil, err
+	}
 	// get nonce
-    var nonce [24]byte
-    copy(nonce[:], encryptedBytes[:24])
+	var nonce [24]byte
+	copy(nonce[:], encryptedBytes[:24])
 	// attempt decrypt
-    decrypted, ok := secretbox.Open(nil, encryptedBytes[24:], &nonce, &keyArray)
-    if !ok {
-        return nil, fmt.Errorf("Decryption failed")
-    }
-    var contents map[string]string
-    if err := json.Unmarshal(decrypted, &contents); err != nil {
-        return nil, err
-    } 
-    return contents, nil
-}
-
-// NewSecretString generates a random secret string of the given length.
-func NewSecretString(length int) string {
-	randBytes := make([]byte, length)
-	_, err := rand.Read(randBytes)
-	if err != nil {
-		return ""
+	decrypted, ok := secretbox.Open(nil, encryptedBytes[24:], &nonce, &keyArray)
+	if !ok {
+		return nil, fmt.Errorf("Decryption failed")
+	}
+	var contents map[string]string
+	if err := json.Unmarshal(decrypted, &contents); err != nil {
+		return nil, err
 	}
-	return base64.URLEncoding.EncodeToString(randBytes)
+	return contents, nil
 }
 
+// salted sha256
 func Hasher(password string) string {
-    conf := config.Conf()
-    salt := conf.Salt
-    toHash := password + salt
-    res := sha256.Sum256([]byte(toHash))
-    return hex.EncodeToString(res[:])
+	conf := config.Conf()
+	salt := conf.Salt
+	toHash := salt + password
+	res := sha256.Sum256([]byte(toHash))
+	return hex.EncodeToString(res[:])
 }
 
+// check if pw matches sysconfig
 func AuthenticateLogin(password string) bool {
 	conf := config.Conf()
 	if Hasher(password) == conf.PwHash {
@@ -175,10 +202,3 @@ func AuthenticateLogin(password string) bool {
 		return false
 	}
 }
-
-func WsIsAuthenticated(conn *websocket.Conn) bool {
-    AuthenticatedClients.Lock()
-    _, exists := AuthenticatedClients.Conns[conn]
-    AuthenticatedClients.Unlock()
-    return exists
-}

+ 33 - 14
broadcast/broadcast.go

@@ -15,16 +15,18 @@ import (
 	"reflect"
 	"strings"
 	"sync"
+	"time"
 
 	"github.com/gorilla/websocket"
 )
 
 var (
-	logger         = slog.New(slog.NewJSONHandler(os.Stdout, nil))
-	clients        = make(map[*websocket.Conn]bool)
-	broadcastState structs.AuthBroadcast
-	unauthState    structs.UnauthBroadcast
-	mu             sync.RWMutex // synchronize access to broadcastState
+	logger           = slog.New(slog.NewJSONHandler(os.Stdout, nil))
+	clients          = make(map[*websocket.Conn]bool)
+	hostInfoInterval = 3 * time.Second // how often we refresh system info
+	broadcastState   structs.AuthBroadcast
+	unauthState      structs.UnauthBroadcast
+	mu               sync.RWMutex // synchronize access to broadcastState
 )
 
 func init() {
@@ -108,6 +110,8 @@ func bootstrapBroadcastState(config structs.SysConfig) (structs.AuthBroadcast, e
 		errmsg := fmt.Sprintf("Error updating broadcast state:", err)
 		logger.Error(errmsg)
 	}
+	// start looping host info
+	go HostStatusLoop()
 	// return the boostrapped result
 	res = GetState()
 	return res, nil
@@ -316,7 +320,8 @@ func GetState() structs.AuthBroadcast {
 func GetStateJson() ([]byte, error) {
 	mu.Lock()
 	defer mu.Unlock()
-	broadcastJson, err := json.Marshal(broadcastState)
+	bState := GetState()
+	broadcastJson, err := json.Marshal(bState)
 	if err != nil {
 		errmsg := fmt.Sprintf("Error marshalling response: %v", err)
 		logger.Error(errmsg)
@@ -327,19 +332,33 @@ func GetStateJson() ([]byte, error) {
 
 // broadcast the global state to auth'd clients
 func BroadcastToClients() error {
-	authJson, err := json.Marshal(broadcastState)
+	authJson, err := GetStateJson()
 	if err != nil {
 		errmsg := fmt.Errorf("Error marshalling auth broadcast:", err)
 		return errmsg
 	}
-	for client := range clients { 
-		_, authenticated := auth.AuthenticatedClients.Conns[client]
-		if authenticated {
-			if err := client.WriteMessage(websocket.TextMessage, authJson); err != nil {
-				logger.Error("Error writing response:", err)
-				return err
-			}
+	auth.AuthenticatedClients.Lock()
+	defer auth.AuthenticatedClients.Unlock()
+	for client := range auth.AuthenticatedClients.Conns {
+		if err := client.WriteMessage(websocket.TextMessage, authJson); err != nil {
+			logger.Error(fmt.Sprintf("Error writing response: %v", err))
+			return err
 		}
 	}
 	return nil
 }
+
+// refresh loop for host info
+func HostStatusLoop() {
+	ticker := time.NewTicker(hostInfoInterval)
+	for {
+		select {
+		case <-ticker.C:
+			update := constructSystemInfo()
+			err := UpdateBroadcastState(update)
+			if err != nil {
+				logger.Warn(fmt.Sprintf("Error updating system status: %v", err))
+			}
+		}
+	}
+}

+ 15 - 0
config/config.go

@@ -3,12 +3,14 @@ package config
 // code for managing groundseg and container configurations
 
 import (
+	"encoding/base64"
 	"encoding/json"
 	"fmt"
 	"goseg/defaults"
 	"goseg/structs"
 	"io/ioutil"
 	"log/slog"
+	"math/rand"
 	"net"
 	"os"
 	"path/filepath"
@@ -70,12 +72,14 @@ func init() {
 		}
 		// generate and insert wireguard keys
 		wgPriv, wgPub, err := WgKeyGen()
+		salt := RandString(32)
 		if err != nil {
 			logger.Error(fmt.Sprintf("%v", err))
 		} else {
 			err = UpdateConf(map[string]interface{}{
 				"Pubkey":  wgPub,
 				"Privkey": wgPriv,
+				"Salt":    salt,
 			})
 			if err != nil {
 				logger.Error(fmt.Sprintf("%v", err))
@@ -208,3 +212,14 @@ func NetCheck(netCheck string) bool {
 	}
 	return internet
 }
+
+// generates a random secret string of the input length
+func RandString(length int) string {
+	randBytes := make([]byte, length)
+	_, err := rand.Read(randBytes)
+	if err != nil {
+		logger.Warn("Random error :s")
+		return ""
+	}
+	return base64.URLEncoding.EncodeToString(randBytes)
+}

+ 6 - 6
structs/ws.go

@@ -7,17 +7,17 @@ type WsType struct {
 }
 
 type WsPayload struct {
-	ID      string      `json:"id"`
-	Payload interface{} `json:"payload"`
-	Token   TokenStruct `json:"token"`
+	ID      string        `json:"id"`
+	Payload interface{}   `json:"payload"`
+	Token   WsTokenStruct `json:"token"`
 }
 
-type TokenStruct struct {
+type WsTokenStruct struct {
 	ID    string `json:"id"`
 	Token string `json:"token"`
 }
 
-type LoginPayload struct {
+type WsLoginPayload struct {
 	Type     string `json:"type"`
 	Password string `json:"password"`
-}
+}

+ 51 - 25
ws/ws.go

@@ -40,12 +40,11 @@ var (
 //     // rest of logic
 // }
 
-
 // switch on ws event cases
 func WsHandler(w http.ResponseWriter, r *http.Request) {
 	conn, err := upgrader.Upgrade(w, r, nil)
 	if err != nil {
-		fmt.Println(err)
+		logger.Error(fmt.Sprintf("Couldn't upgrade websocket connection: %v", err))
 		return
 	}
 	// manage broadcasts and clients thru the broadcast package
@@ -75,37 +74,20 @@ func WsHandler(w http.ResponseWriter, r *http.Request) {
 			return
 		}
 		var prelim structs.WsType
+		var payload structs.WsPayload
 		if err := json.Unmarshal(msg, &prelim); err != nil {
 			fmt.Println("Error unmarshalling message:", err)
 			continue
 		}
 		switch prelim.Payload.Type {
 		case "login":
-			var payload structs.WsPayload
-			payload.Payload = structs.LoginPayload{}
-			if err := json.Unmarshal(msg, &payload); err != nil {
-				fmt.Println("Error unmarshalling message:", err)
-				continue
+			if err = loginHandler(msg, payload); err != nil {
+				logger.Error(fmt.Sprintf("%v", err))
 			}
-			loginPayload, ok := payload.Payload.(structs.LoginPayload)
-			if !ok {
-				fmt.Println("Error casting to LoginPayload")
-				continue
+		case "verify":
+			if err = verifyHandler(msg, payload, r, conn); err != nil {
+				logger.Error(fmt.Sprintf("%v", err))
 			}
-			isAuthenticated := auth.AuthenticateLogin(loginPayload.Password)
-			if isAuthenticated {
-				_, err := auth.CreateToken(conn, false)
-				if err != nil {
-					fmt.Println("Error creating token:", err)
-					continue
-				}
-				auth.AuthenticatedClients.Lock()
-				auth.AuthenticatedClients.Conns[conn] = true
-				auth.AuthenticatedClients.Unlock()
-                } else {
-                    // Handle invalid login
-                    fmt.Println("Invalid login attempt")
-                }
 		case "setup":
 			logger.Info("Setup")
 			// setup.Setup(payload)
@@ -134,3 +116,47 @@ func WsHandler(w http.ResponseWriter, r *http.Request) {
 		}
 	}
 }
+
+func loginHandler(msg []byte, payload structs.WsPayload) error {
+	logger.Info("Login")
+	now := time.Now().Format("2006-01-02_15:04:05")
+	payload.Payload = structs.WsLoginPayload{}
+	if err := json.Unmarshal(msg, &payload); err != nil {
+		return fmt.Errorf("Error unmarshalling message: %v", err)
+	}
+	loginPayload, ok := payload.Payload.(structs.WsLoginPayload)
+	if !ok {
+		return fmt.Errorf("Error casting to LoginPayload")
+	}
+	isAuthenticated := auth.AuthenticateLogin(loginPayload.Password)
+	if isAuthenticated {
+		if err := auth.AddSession(payload.Token.ID, payload.Token.Token, now, true); err != nil {
+			return fmt.Errorf("Unable to process login: %v", err)
+		}
+	} else {
+		logger.Info("Login failed")
+		if err := auth.AddSession(payload.Token.ID, payload.Token.Token, now, false); err != nil {
+			return fmt.Errorf("Unable to process login: %v", err)
+		}
+	}
+	return nil
+}
+
+func verifyHandler(msg []byte, payload structs.WsPayload, r *http.Request, conn *websocket.Conn) error {
+	payload.Payload = structs.WsLoginPayload{}
+	// if we can't unmarshal, assume no token
+	if err := json.Unmarshal(msg, &payload); err != nil {
+		resp, err := auth.CreateToken(conn, r, false)
+		if err != nil {
+			fmt.Errorf("Couldn't create token: %v", err)
+		}
+		respJson, err := json.Marshal(resp)
+		if err != nil {
+			return fmt.Errorf("Error marshalling token: %v", err)
+		}
+		if err := conn.WriteMessage(websocket.TextMessage, respJson); err != nil {
+			return fmt.Errorf("Error writing response: %v", err)
+		}
+	}
+	return nil
+}