broadcast.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. package broadcast
  2. import (
  3. "encoding/json"
  4. "goseg/config"
  5. "goseg/docker"
  6. "goseg/structs"
  7. "log/slog"
  8. "os"
  9. "reflect"
  10. "strings"
  11. "sync"
  12. "fmt"
  13. "github.com/gorilla/websocket"
  14. )
  15. var (
  16. logger = slog.New(slog.NewJSONHandler(os.Stdout, nil))
  17. clients = make(map[*websocket.Conn]bool)
  18. broadcastState structs.AuthBroadcast
  19. mu sync.RWMutex // synchronize access to broadcastState
  20. )
  21. func init() {
  22. // initialize broadcastState global var
  23. config := config.Conf()
  24. broadcast, err := bootstrapBroadcastState(config)
  25. if err != nil {
  26. errmsg := fmt.Sprintf("Unable to initialize broadcast: %v",err)
  27. panic(errmsg)
  28. }
  29. broadcastState = broadcast
  30. }
  31. // adds ws client
  32. func RegisterClient(conn *websocket.Conn) {
  33. clients[conn] = true
  34. broadcastJson, err := GetStateJson()
  35. if err != nil {
  36. return
  37. }
  38. // when a new ws client registers, send them the current broadcast
  39. if err := conn.WriteMessage(websocket.TextMessage, broadcastJson); err != nil {
  40. fmt.Println("Error writing response:", err)
  41. return
  42. }
  43. }
  44. // remove ws client
  45. func UnregisterClient(conn *websocket.Conn) {
  46. delete(clients, conn)
  47. }
  48. // take in config file and addt'l info to initialize broadcast
  49. func bootstrapBroadcastState(config structs.SysConfig) (structs.AuthBroadcast, error) {
  50. var res structs.AuthBroadcast
  51. piers := config.Piers
  52. pierStatus, err := docker.GetShipStatus(piers)
  53. if err != nil {
  54. errmsg := fmt.Sprintf("Unable to bootstrap urbit states: %v",err)
  55. logger.Error(errmsg)
  56. return res, err
  57. }
  58. updates := make(map[string]structs.Urbit)
  59. for pier, status := range pierStatus {
  60. urbit := structs.Urbit{}
  61. if existingUrbit, exists := broadcastState.Urbits[pier]; exists {
  62. // If the ship already exists in broadcastState, use its current state
  63. urbit = existingUrbit
  64. }
  65. isRunning := (status == "Up" || strings.HasPrefix(status, "Up "))
  66. urbit.Info.Running = isRunning
  67. updates[pier] = urbit
  68. }
  69. // update broadcastState
  70. err = UpdateBroadcastState(map[string]interface{}{
  71. "Urbits": updates,
  72. })
  73. if err != nil {
  74. errmsg := fmt.Sprintf("Unable to update broadcast state: %v", err)
  75. logger.Error(errmsg)
  76. return res, err
  77. }
  78. res = GetState()
  79. return res, nil
  80. }
  81. // update broadcastState with a map of items
  82. func UpdateBroadcastState(values map[string]interface{}) error {
  83. mu.Lock()
  84. defer mu.Unlock()
  85. v := reflect.ValueOf(&broadcastState).Elem()
  86. for key, value := range values {
  87. field := v.FieldByName(key)
  88. if !field.IsValid() || !field.CanSet() {
  89. return fmt.Errorf("field %s does not exist or is not settable", key)
  90. }
  91. if err := recursiveUpdate(field, reflect.ValueOf(value)); err != nil {
  92. return err
  93. }
  94. }
  95. BroadcastToClients()
  96. return nil
  97. }
  // recursiveUpdate writes src into dst without discarding existing nested
  // content.
  //
  // When both are maps, every key of src is merged into dst: existing entries
  // are updated in place (recursively for nested maps), missing entries are
  // created, and keys that only exist in dst are left alone. A nil dst map is
  // allocated on the first insert. For any other kind, src must have exactly
  // the type of dst and replaces it.
  //
  // It returns an error when dst is not settable, when src is nil/invalid, or
  // when the types (or map key types) do not match.
  func recursiveUpdate(dst, src reflect.Value) error {
  	if !dst.CanSet() {
  		return fmt.Errorf("field is not settable")
  	}
  	// Values read out of a map[K]interface{} arrive wrapped in an interface;
  	// unwrap them so the concrete type is what gets compared and stored.
  	if src.IsValid() && src.Kind() == reflect.Interface && dst.Kind() != reflect.Interface {
  		src = src.Elem()
  	}
  	// An invalid Value (from a nil input) has no type; calling Type on it panics.
  	if !src.IsValid() {
  		return fmt.Errorf("type mismatch: expected %s, got nil", dst.Type())
  	}

  	// If both dst and src are maps, merge them key by key.
  	if dst.Kind() == reflect.Map && src.Kind() == reflect.Map {
  		// MapIndex panics on keys that are not assignable to the map's key type.
  		if !src.Type().Key().AssignableTo(dst.Type().Key()) {
  			return fmt.Errorf("type mismatch: expected %s, got %s", dst.Type(), src.Type())
  		}
  		elemType := dst.Type().Elem()
  		for _, key := range src.MapKeys() {
  			// Values returned by MapIndex are not settable, so the update is
  			// applied to a settable scratch value seeded with the current
  			// entry (if any) and stored back afterwards.
  			elem := reflect.New(elemType).Elem()
  			if current := dst.MapIndex(key); current.IsValid() {
  				elem.Set(current)
  			}
  			if err := recursiveUpdate(elem, src.MapIndex(key)); err != nil {
  				return err
  			}
  			if dst.IsNil() {
  				dst.Set(reflect.MakeMap(dst.Type()))
  			}
  			dst.SetMapIndex(key, elem)
  		}
  		return nil
  	}

  	// For non-map fields or direct updates.
  	if dst.Type() != src.Type() {
  		return fmt.Errorf("type mismatch: expected %s, got %s", dst.Type(), src.Type())
  	}
  	dst.Set(src)
  	return nil
  }
  130. // return broadcast state
  131. func GetState() structs.AuthBroadcast {
  132. mu.Lock()
  133. defer mu.Unlock()
  134. return broadcastState
  135. }
  136. // return json string of current broadcast state
  137. func GetStateJson() ([]byte, error) {
  138. mu.Lock()
  139. defer mu.Unlock()
  140. broadcastJson, err := json.Marshal(broadcastState)
  141. if err != nil {
  142. errmsg := fmt.Sprintf("Error marshalling response: %v", err)
  143. logger.Error(errmsg)
  144. return nil, err
  145. }
  146. return broadcastJson, nil
  147. }
  148. // broadcast the global state to all clients
  149. func BroadcastToClients() error {
  150. broadcastJson, err := json.Marshal(broadcastState)
  151. if err != nil {
  152. logger.Error("Error marshalling response:", err)
  153. return err
  154. }
  155. for client := range clients {
  156. if err := client.WriteMessage(websocket.TextMessage, broadcastJson); err != nil {
  157. logger.Error("Error writing response:", err)
  158. return err
  159. }
  160. }
  161. return nil
  162. }