diff --git a/internal/state/state.go b/internal/state/state.go new file mode 100644 index 0000000..df80bbb --- /dev/null +++ b/internal/state/state.go @@ -0,0 +1,396 @@ +package state + +/* STATE + * This package encapsulates various state information + * state is represented in various global singletons + * Additionally, this package offers a pub/sub interface + * for various aspects of state + */ + +import ( + "slices" + "sync" + "time" + + "github.com/bwmarrin/discordgo" +) + +/* Event interface is meant to encapsulate a general interface + * extendable for any different action the bot takes or that a + * user takes. + */ + +const ( + BadEventTypeError = "bad event type" + EventValidationFailedError = "event failed validation: " +) + +var DiscordSession *discordgo.Session +var eventMutex sync.RWMutex +var eventSubscriptionCache = [NumEventTypes][]chan Event{} +var eventCache = []Event{} + +// TODO: Configurable +var eventChannelBufferSize = 256 + +// TODO: pruned events go to DISK +// ASSUMES eventCache is ordered by age (oldest first) +func pruneEventCache() { + eventMutex.Lock() + defer eventMutex.Unlock() + + oldCacheInvalid := false + newCache := []Event{} + for _, obj := range eventCache { + if time.Since(obj.Time()).Hours() > 24*10 { + oldCacheInvalid = true + } else if oldCacheInvalid { + newCache = append(newCache, obj) + } + } + + eventCache = newCache +} + +type EventType int8 + +const ( + Vote EventType = iota + Challenge + Restoration + UserActive + // ... + + // leave this last + NumEventTypes +) + +func (et EventType) Validate() error { + if et < 0 || et >= NumEventTypes { + return stringErrorType(BadEventTypeError) + } + + return nil +} + +type Event interface { + // gets EventType associated with event + Type() EventType + + // gets time event created + Time() time.Time + + // gets internal metadata associated with event + // may be read only depending on event implementation + Data() map[string]string + + // validates state of internal metadata per EventType + Validate() error +} + +// TODO: Something better than this +type stringErrorType string + +func (s stringErrorType) Error() string { + return string(s) +} + +/* adds a new subscriber channel to the event subscription cache + * and returns the channel that it will publish notifications on + * + * note: channel is prefilled with at most eventChannelBufferSize + * historical events. Truncated if history exceeds event channel + * buffer size. + */ +func (et EventType) SubscribeWithHistory() (chan Event, error) { + if err := et.Validate(); err != nil { + return nil, err + } + + ch := make(chan Event, eventChannelBufferSize) + + eventMutex.Lock() + eventSubscriptionCache[et] = append( + eventSubscriptionCache[et], + ch, + ) + eventMutex.Unlock() + + eventMutex.RLock() + defer eventMutex.RUnlock() + numEventsAdded := 0 + for _, ev := range slices.Backward(eventCache) { + if numEventsAdded >= eventChannelBufferSize { + break + } + + if ev.Type() == et { + ch <- ev + numEventsAdded += 1 + } + } + + return ch, nil +} + +/* adds a new subscriber channel to the event subscription cache + * and returns the channel that it will publish notifications on + */ +func (et EventType) Subscribe() (chan Event, error) { + if err := et.Validate(); err != nil { + return nil, err + } + + ch := make(chan Event, eventChannelBufferSize) + + eventMutex.Lock() + defer eventMutex.Unlock() + eventSubscriptionCache[et] = append( + eventSubscriptionCache[et], + ch, + ) + + return ch, nil +} + +func PublishEvent(e Event) error { + if err := ValidateEvent(e); err != nil { + return stringErrorType( + EventValidationFailedError + err.Error(), + ) + } + + eventMutex.Lock() + eventCache = append(eventCache, e) + eventMutex.Unlock() + eventMutex.RLock() + defer eventMutex.RUnlock() + for _, c := range eventSubscriptionCache[e.Type()] { + if float32(len(c)) > (float32(eventChannelBufferSize) * 0.25) { + if len(c) == eventChannelBufferSize { + // TODO: log that this publish is blocking + // on a full channel + } + } + c <- e + } + + return nil +} + +func ValidateEvent(e Event) error { + if err := e.Type().Validate(); err != nil { + return err + } + + if err := e.Validate(); err != nil { + return err + } + + return nil +} + +/* gets all events of type T in cache + * that also share all field values in + * map 'filters' + */ +func GetMatchingEvents( + t EventType, + filters map[string]string, +) ([]Event, error) { + matches := []Event{} + + if err := t.Validate(); err != nil { + return matches, err + } + + eventMutex.RLock() + defer eventMutex.RUnlock() + +Events: + for _, e := range eventCache { + if e.Type() != t { + continue + } + for k, v := range filters { + val, found := e.Data()[k] + if !found || val != v { + continue Events + } + } + + matches = append(matches, e) + } + + return matches, nil +} + +type VoteEvent map[string]string + +const ( + VoteMissingKeyError = "vote data not found: " + VoteCreatedKey = "created" + VoteRequesterKey = "requester" + VoteActionKey = "action" + + VoteResultKey = "result" + VoteResultPass = "pass" + VoteResultFail = "fail" + VoteResultTie = "tie" + VoteResultTimeout = "timeout" + VoteBadResultError = "vote has invalid result: " + VoteNotFinishedError = "vote has result but isnt finished" + VoteMissingResultError = "vote finished but missing result" + + VoteStatusKey = "status" + VoteStatusInProgress = "in_progress" + VoteStatusFinalized = "finalized" + VoteStatusTimeout = "timed_out" + VoteBadStatusError = "vote has invalid status: " + + VeryOldVote = "1990-01-01T00:00:00Z" +) + +func (ve VoteEvent) Type() EventType { + return Vote +} + +func (ve VoteEvent) Time() time.Time { + t, e := time.Parse(time.RFC3339, ve[VoteCreatedKey]) + if e != nil { + // we have a corrupted event + // TODO: log first + // return old time so that this event gets + // pruned from cache + tooOld, _ := time.Parse( + time.RFC3339, + VeryOldVote, + ) + return tooOld + } + return t +} + +func (ve VoteEvent) Data() map[string]string { + return map[string]string(ve) +} + +func (ve VoteEvent) Validate() error { + // make sure action, requester, and created are set + for _, key := range []string{ + VoteActionKey, + VoteRequesterKey, + VoteCreatedKey, + VoteStatusKey, + } { + if _, found := ve[key]; !found { + return stringErrorType( + VoteMissingKeyError + key) + } + } + + status := ve[VoteStatusKey] + if status != VoteStatusTimeout && + status != VoteStatusInProgress && + status != VoteStatusFinalized { + return stringErrorType(VoteBadStatusError + status) + } + + result, hasResult := ve[VoteResultKey] + if hasResult && status == VoteStatusInProgress { + return stringErrorType(VoteNotFinishedError) + } + if status != VoteStatusInProgress && !hasResult { + return stringErrorType(VoteMissingResultError) + } + + if hasResult && + (result != VoteResultPass && + result != VoteResultFail && + result != VoteResultTie && + result != VoteResultTimeout) { + + return stringErrorType(VoteBadResultError + result) + } + + return nil +} + +type UserEvent struct { + uid string + created time.Time +} + +const ( + UserEventUserKey = "user" + UserEventCreatedKey = "created" + UserEventBadUserError = "event has bad user" +) + +func (ue UserEvent) Time() time.Time { + return ue.created +} + +func (ue UserEvent) Data() map[string]string { + return map[string]string{ + UserEventUserKey: ue.uid, + UserEventCreatedKey: ue.created.Local().String(), + } +} + +func (ue UserEvent) Validate() error { + if DiscordSession != nil { + _, err := DiscordSession.User(ue.uid) + return err + } else { + // TODO: Log validation failure + // I would love to know how to actually fail here + // and still have unit testable code. + return nil + } +} + +type ChallengeEvent struct { + UserEvent +} + +func (ce ChallengeEvent) Type() EventType { + return Challenge +} + +func NewChallengeEvent(user string) ChallengeEvent { + return ChallengeEvent{UserEvent{ + uid: user, + created: time.Now(), + }} +} + +type RestorationEvent struct { + UserEvent +} + +func (re RestorationEvent) Type() EventType { + return Restoration +} + +func NewRestorationEvent(user string) RestorationEvent { + return RestorationEvent{UserEvent{ + uid: user, + created: time.Now(), + }} +} + +type UserActiveEvent struct { + UserEvent +} + +func (ua UserActiveEvent) Type() EventType { + return UserActive +} + +func NewUserActiveEvent(user string) UserActiveEvent { + return UserActiveEvent{UserEvent{ + uid: user, + created: time.Now(), + }} +} diff --git a/internal/state/state_test.go b/internal/state/state_test.go new file mode 100644 index 0000000..3d98b1c --- /dev/null +++ b/internal/state/state_test.go @@ -0,0 +1,213 @@ +package state + +import ( + "fmt" + "testing" + "time" +) + +/* WARNING: + * Cannot run these tests in parallel! + * limitation of SetupTest and CleanupTest + */ + +const TestTok = "TEST_NAME" + +func SetupTest(t *testing.T) { + old, _ := time.Parse( + time.RFC3339, + VeryOldVote, + ) + + for i := range 270 { + if err := PublishEvent(UserActiveEvent{UserEvent{ + uid: fmt.Sprintf("%d", i), + created: old, + }}); err != nil { + t.Errorf("Failed to add event: %e", err) + } + } + + PublishEvent(UserActiveEvent{UserEvent{ + uid: fmt.Sprintf(TestTok), + created: time.Now(), + }}) + + PublishEvent(ChallengeEvent{UserEvent{ + uid: fmt.Sprintf(TestTok), + created: time.Now(), + }}) + + if len(eventCache) != 272 { + t.Errorf("Unexpected number of events in cache: %d", + len(eventCache)) + } +} + +func CleanupTest() { + eventSubscriptionCache = [NumEventTypes][]chan Event{} + eventCache = []Event{} +} + +func TestPubSub(t *testing.T) { + SetupTest(t) + + c, e := UserActive.SubscribeWithHistory() + if e != nil { + t.Errorf("Error subscribing to UserActive events: %e", e) + } + +Loop: + for i := 0; true; i++ { + select { + case e, ok := <-c: + if !ok { + t.Errorf("Subscription Channel Closed") + } + if e.Type() != UserActive { + t.Errorf("Non UserActive Event in UserActive subscription: %v", e.Type()) + } + default: + if i == eventChannelBufferSize { + break Loop + } else { + t.Errorf("Unexpected number of events in channel: %d", i) + } + } + } + + PublishEvent(UserActiveEvent{UserEvent{ + uid: "uniqueToken", + created: time.Now(), + }}) + + select { + case e, ok := <-c: + if !ok || e.Data()[UserEventUserKey] != "uniqueToken" { + t.Errorf("didnt read correct event from channel: %v", e) + } + default: + t.Errorf("New event not published to subscription!") + } + + CleanupTest() +} + +func TestFilterCache(t *testing.T) { + SetupTest(t) + + events, err := GetMatchingEvents( + UserActive, + map[string]string{ + UserEventUserKey: TestTok, + }, + ) + + if err != nil { + t.Errorf("Error filtering events: %e", err) + } + + if len(events) != 1 { + t.Errorf("Got too many events from filter: %d", len(events)) + } + + if events[0].Type() != UserActive { + t.Errorf("Got wrong event!: %+v", events[0]) + } + + CleanupTest() +} + +func TestPruneCache(t *testing.T) { + SetupTest(t) + pruneEventCache() + + if len(eventCache) != 2 { + t.Errorf("Incorrect number of remaining events: %d", len(eventCache)) + } + + CleanupTest() +} + +func TestVoteEventValidations(t *testing.T) { + var err error + + if err = VoteEvent( + map[string]string{ + VoteRequesterKey: "r", + VoteCreatedKey: "c", + VoteStatusKey: VoteStatusInProgress, + }, + ).Validate(); err.Error() != VoteMissingKeyError+VoteActionKey { + t.Errorf("Unexpected error from validation: %e", err) + } + + if err = VoteEvent( + map[string]string{ + VoteActionKey: "a", + VoteRequesterKey: "r", + VoteCreatedKey: "c", + VoteStatusKey: VoteStatusInProgress, + }, + ).Validate(); err != nil { + t.Errorf("Unexpected error: %e", err) + } + + if err = VoteEvent( + map[string]string{ + VoteActionKey: "a", + VoteRequesterKey: "r", + VoteCreatedKey: "c", + VoteStatusKey: "s", + }, + ).Validate(); err.Error() != VoteBadStatusError+"s" { + t.Errorf("Unexpected or no error: %e", err) + } + + if err = VoteEvent( + map[string]string{ + VoteActionKey: "a", + VoteRequesterKey: "r", + VoteCreatedKey: "c", + VoteStatusKey: VoteStatusInProgress, + VoteResultKey: VoteResultFail, + }, + ).Validate(); err.Error() != VoteNotFinishedError { + t.Errorf("Unexpected or no error: %e", err) + } + + if err = VoteEvent( + map[string]string{ + VoteActionKey: "a", + VoteRequesterKey: "r", + VoteCreatedKey: "c", + VoteStatusKey: VoteStatusFinalized, + }, + ).Validate(); err.Error() != VoteMissingResultError { + t.Errorf("Unexpected or no error: %e", err) + } + + if err = VoteEvent( + map[string]string{ + VoteActionKey: "a", + VoteRequesterKey: "r", + VoteCreatedKey: "c", + VoteStatusKey: VoteStatusFinalized, + VoteResultKey: "r", + }, + ).Validate(); err.Error() != VoteBadResultError+"r" { + t.Errorf("Unexpected or no error: %e", err) + } + + if err = VoteEvent( + map[string]string{ + VoteActionKey: "a", + VoteRequesterKey: "r", + VoteCreatedKey: "c", + VoteStatusKey: VoteStatusFinalized, + VoteResultKey: VoteResultFail, + }, + ).Validate(); err != nil { + t.Errorf("Unexpected or no error: %e", err) + } +} diff --git a/main.go b/main.go index 321ae32..3ece662 100644 --- a/main.go +++ b/main.go @@ -9,6 +9,7 @@ import ( "github.com/bwmarrin/discordgo" "gitlab.com/whom/bingobot/internal/config" "gitlab.com/whom/bingobot/internal/logging" + "gitlab.com/whom/bingobot/internal/state" ) var ( @@ -44,9 +45,9 @@ func main() { } func startBot() error { - session, _ := discordgo.New("Bot " + *Token) + state.DiscordSession, _ = discordgo.New("Bot " + *Token) - err := session.Open() + err := state.DiscordSession.Open() if err != nil { Log.Error("could not open discord session", "type", "error", "error", err) return err @@ -60,7 +61,7 @@ func startBot() error { Log.Info("shutting down gracefully", "type", "shutdown") - err = session.Close() + err = state.DiscordSession.Close() if err != nil { Log.Error("could not close discord session gracefully", "type", "error", "error", err) return err