From cdfbc3769f8da6cef3180a9af9b7ecdbe6c13750 Mon Sep 17 00:00:00 2001 From: Tanner Kvarfordt Date: Sat, 19 Feb 2022 17:23:04 -0700 Subject: [PATCH] Mad Libs (#51) * Added /madlib * madlib module now reads model and mask from config * Updated README --- README.md | 1 + config/madlib.json | 4 ++ go.mod | 14 +++--- go.sum | 16 +++++++ kardbot/commands.go | 13 +++++ kardbot/madlibs.go | 113 ++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 154 insertions(+), 7 deletions(-) create mode 100644 config/madlib.json create mode 100644 kardbot/madlibs.go diff --git a/README.md b/README.md index 999135d..7a614e1 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,7 @@ A discord bot destined for greatness. - [x] Generate a story from a user's prompt - [x] Allow server admins to generate and edit a role selection menu - [x] Allow users to create embeds +- [x] Madlibs - [ ] Inform users when Kard-bot is updated - [ ] Mock certain questions or phrases - [ ] "Quack" any time a user types an expletive diff --git a/config/madlib.json b/config/madlib.json new file mode 100644 index 0000000..007d273 --- /dev/null +++ b/config/madlib.json @@ -0,0 +1,4 @@ +{ + "model": "roberta-base", + "model-mask": "" +} \ No newline at end of file diff --git a/go.mod b/go.mod index 5e97c73..247929b 100644 --- a/go.mod +++ b/go.mod @@ -3,13 +3,13 @@ module github.com/TannerKvarfordt/Kard-bot go 1.17 require ( - github.com/TannerKvarfordt/hfapigo v0.0.1 + github.com/TannerKvarfordt/hfapigo v0.0.3 github.com/TannerKvarfordt/imgflipgo v1.0.5 github.com/bwmarrin/discordgo v0.23.3-0.20211228023845-29269347e820 github.com/deadshot465/owoify-go v1.0.1 - github.com/forPelevin/gomoji v1.1.1 + github.com/forPelevin/gomoji v1.1.2 github.com/gabriel-vasile/mimetype v1.4.0 - github.com/go-co-op/gocron v1.11.0 + github.com/go-co-op/gocron v1.12.0 github.com/google/uuid v1.3.0 github.com/joho/godotenv v1.4.0 github.com/lucasb-eyer/go-colorful v1.2.0 @@ -24,14 +24,14 @@ require ( github.com/golang/protobuf v1.5.2 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/gorilla/schema v1.2.0 // indirect - github.com/gorilla/websocket v1.4.2 // indirect + github.com/gorilla/websocket v1.5.0 // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/robfig/cron/v3 v3.0.1 // indirect - golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 // indirect - golang.org/x/net v0.0.0-20211216030914-fe4d6282115f // indirect + golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect + golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd // indirect golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 // indirect golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect - golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e // indirect + golang.org/x/sys v0.0.0-20220209214540-3681064d5158 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/protobuf v1.27.1 // indirect ) diff --git a/go.sum b/go.sum index 6adf463..b19a2f6 100644 --- a/go.sum +++ b/go.sum @@ -65,6 +65,10 @@ github.com/OpenPeeDeeP/depguard v1.0.1/go.mod h1:xsIw86fROiiwelg+jB2uM9PiKihMMmU github.com/StackExchange/wmi v1.2.1/go.mod h1:rcmrprowKIVzvc+NUiLncP2uuArMWLCbu9SBzvHz7e8= github.com/TannerKvarfordt/hfapigo v0.0.1 h1:VsFZx76js8QSfqS8dyej6RCcwwrcK0f20Jt/6YAVAg8= github.com/TannerKvarfordt/hfapigo v0.0.1/go.mod h1:XKhHnldnhSZNCrRZclWr2D/A/UlEGaQ8OxpM7r8wQpE= +github.com/TannerKvarfordt/hfapigo v0.0.2 h1:1rUFzQDPwq0mGj1pboqbG28nHB0NitWg4Ct6EKi6ZgM= +github.com/TannerKvarfordt/hfapigo v0.0.2/go.mod h1:XKhHnldnhSZNCrRZclWr2D/A/UlEGaQ8OxpM7r8wQpE= +github.com/TannerKvarfordt/hfapigo v0.0.3 h1:qskK5VwlqU/BrBWzXRUDPut34gmarY4bWuVOlutTaKM= +github.com/TannerKvarfordt/hfapigo v0.0.3/go.mod h1:XKhHnldnhSZNCrRZclWr2D/A/UlEGaQ8OxpM7r8wQpE= github.com/TannerKvarfordt/imgflipgo v1.0.5 h1:y9q6Vt4cq0bH3FvlDuEfqabZURtdcD7KYZ4EYdWDWIw= github.com/TannerKvarfordt/imgflipgo v1.0.5/go.mod h1:wvuZ3+UAyCtKuvqn0Thcjhj31yetO8/baXPBYI5i9Ug= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= @@ -159,6 +163,8 @@ github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4 github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= github.com/forPelevin/gomoji v1.1.1 h1:Jmt2USLSTkF5dFX7y8nDObS5Zs1NIdIOjHMCx6mime0= github.com/forPelevin/gomoji v1.1.1/go.mod h1:Z5cUlNvnrRQPxwMxc8hmn+ZAm0p8WhqE0FvDNnF5Mkw= +github.com/forPelevin/gomoji v1.1.2 h1:WtOeuEBAmwsGOvsPfvufTrdqmsJhQF77nGWveFQTDd8= +github.com/forPelevin/gomoji v1.1.2/go.mod h1:Z5cUlNvnrRQPxwMxc8hmn+ZAm0p8WhqE0FvDNnF5Mkw= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.5.1/go.mod h1:T3375wBYaZdLLcVNkcVbzGHY7f1l/uK5T5Ai1i3InKU= @@ -169,6 +175,8 @@ github.com/gabriel-vasile/mimetype v1.4.0/go.mod h1:fA8fi6KUiG7MgQQ+mEWotXoEOvmx github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-co-op/gocron v1.11.0 h1:ujOMubCpGcTxnnR/9vJIPIEpgwuAjbueAYqJRNr+nHg= github.com/go-co-op/gocron v1.11.0/go.mod h1:qtlsoMpHlSdIZ3E/xuZzrrAbeX3u5JtPvWf2TcdutU0= +github.com/go-co-op/gocron v1.12.0 h1:RahikbAIhp/wlNBraICMZfby7bdkeCXe+QQSW323Lpo= +github.com/go-co-op/gocron v1.12.0/go.mod h1:qtlsoMpHlSdIZ3E/xuZzrrAbeX3u5JtPvWf2TcdutU0= github.com/go-critic/go-critic v0.6.1/go.mod h1:SdNCfU0yF3UBjtaZGw6586/WocupMOJuiqgom5DsQxM= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= @@ -311,6 +319,8 @@ github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gostaticanalysis/analysisutil v0.0.0-20190318220348-4088753ea4d3/go.mod h1:eEOZF4jCKGi+aprrirO9e7WKB3beBRtWgqGunKl6pKE= github.com/gostaticanalysis/analysisutil v0.0.3/go.mod h1:eEOZF4jCKGi+aprrirO9e7WKB3beBRtWgqGunKl6pKE= github.com/gostaticanalysis/analysisutil v0.1.0/go.mod h1:dMhHRU9KTiDcuLGdy87/2gTR8WruwYZrKdRq9m1O6uw= @@ -704,6 +714,8 @@ golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 h1:0es+/5331RGQPcXlMfP+WrnIIS6dNnNRe0WB02W0F4M= golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220214200702-86341886e292 h1:f+lwQ+GtmgoY+A2YaQxlSOnDjXcQ7ZRLWOHbC6HtRqE= +golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -795,6 +807,8 @@ golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211216030914-fe4d6282115f h1:hEYJvxw1lSnWIl8X9ofsYMklzaDs90JI2az5YMd4fPM= golang.org/x/net v0.0.0-20211216030914-fe4d6282115f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd h1:O7DYs+zxREGLKzKoMQrtrEacpb0ZVXA5rIwylE2Xchk= +golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -905,6 +919,8 @@ golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211013075003-97ac67df715c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220209214540-3681064d5158 h1:rm+CHSpPEEW2IsXUib1ThaHIjuBVZjxNgSKmBLFfD4c= +golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/kardbot/commands.go b/kardbot/commands.go index 03de0e8..be7a20b 100644 --- a/kardbot/commands.go +++ b/kardbot/commands.go @@ -298,6 +298,18 @@ func getCommands() []*discordgo.ApplicationCommand { }, }, }, + { + Name: madlibCmd, + Description: "The bot will fill in any blanks indicated with " + madlibBlank, + Options: []*discordgo.ApplicationCommandOption{ + { + Type: discordgo.ApplicationCommandOptionString, + Name: "prompt", + Description: fmt.Sprintf("An input prompt containing blanks. Example: 'The %s jumps over the %s.'", madlibBlank, madlibBlank), + Required: true, + }, + }, + }, { Name: storyTimeCmd, Description: "The bot will tell you a short story (but not a good or sensical one) based on a given prompt.", @@ -361,6 +373,7 @@ func getCommandImpls() map[string]onInteractionHandler { storyTimeCmd: storyTime, roleSelectMenuCommand: handleRoleSelectMenuCommand, embedCmd: handleEmbedCmd, + madlibCmd: handleMadLibCmd, } } diff --git a/kardbot/madlibs.go b/kardbot/madlibs.go new file mode 100644 index 0000000..5fbb4c4 --- /dev/null +++ b/kardbot/madlibs.go @@ -0,0 +1,113 @@ +package kardbot + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/TannerKvarfordt/Kard-bot/kardbot/config" + "github.com/TannerKvarfordt/hfapigo" + "github.com/bwmarrin/discordgo" + log "github.com/sirupsen/logrus" +) + +const ( + madlibCmd = "madlib" + madlibBlank = "<>" + madlibConfigFile = "config/madlib.json" +) + +type madlibConfig struct { + Model string `json:"model,omitempty"` + ModelMask string `json:"model-mask,omitempty"` +} + +var madlibCfg = madlibConfig{ + Model: "roberta-base", // https://huggingface.co/roberta-base + ModelMask: "", +} + +func init() { + jsonCfg, err := config.NewJsonConfig(madlibConfigFile) + if err != nil { + log.Fatal(err) + } + + err = json.Unmarshal(jsonCfg.Raw, &madlibCfg) + if err != nil { + log.Fatal(err) + } + + log.Infof("Madlib using %s with mask=%s", madlibCfg.Model, madlibCfg.ModelMask) +} + +func handleMadLibCmd(s *discordgo.Session, i *discordgo.InteractionCreate) { + wg := bot().updateLastActive() + defer wg.Wait() + + if isSelf, err := authorIsSelf(s, i); err != nil { + log.Error(err) + interactionRespondEphemeralError(s, i, true, err) + return + } else if isSelf { + log.Trace("Ignoring message from self") + return + } + + if !strings.Contains(i.ApplicationCommandData().Options[0].StringValue(), madlibBlank) { + s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{ + Type: discordgo.InteractionResponseChannelMessageWithSource, + Data: &discordgo.InteractionResponseData{ + Content: fmt.Sprintf("No blanks provided. Provide a prompt string containing at least one of the following mask: `%s`.\nFor example: `The quick brown %s jumps over the lazy %s.`", madlibBlank, madlibBlank, madlibBlank), + Flags: InteractionResponseFlagEphemeral, + }, + }) + return + } + + err := s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{ + Type: discordgo.InteractionResponseDeferredChannelMessageWithSource, + }) + if err != nil { + log.Error(err) + interactionRespondEphemeralError(s, i, true, err) + return + } + + input := strings.ReplaceAll(i.ApplicationCommandData().Options[0].StringValue(), madlibBlank, madlibCfg.ModelMask) + resp, err := hfapigo.SendFillMaskRequest(madlibCfg.Model, &hfapigo.FillMaskRequest{ + Inputs: []string{input}, + Options: *hfapigo.NewOptions().SetWaitForModel(true), + }) + if err != nil { + log.Error(err) + interactionFollowUpEphemeralError(s, i, true, err) + return + } + + if len(resp) < strings.Count(i.ApplicationCommandData().Options[0].StringValue(), madlibBlank) { + err := fmt.Errorf("too few responses received") + log.Error(err) + interactionFollowUpEphemeralError(s, i, true, err) + return + } + + output := input + for _, mask := range resp { + if len(mask.Masks) == 0 { + err := fmt.Errorf("received empty response") + log.Error(err) + interactionFollowUpEphemeralError(s, i, true, err) + return + } + output = strings.Replace(output, madlibCfg.ModelMask, strings.TrimSpace(mask.Masks[0].TokenStr), 1) + } + + _, err = s.InteractionResponseEdit(s.State.User.ID, i.Interaction, &discordgo.WebhookEdit{ + Content: output, + }) + if err != nil { + log.Error(err) + interactionFollowUpEphemeralError(s, i, true, err) + } +}