• R/O
  • HTTP
  • SSH
  • HTTPS

Commit

Tags
No Tags

Frequently used words (click to add to your profile)

java, c++, android, linux, c#, windows, objective-c, cocoa, 誰得, qt, python, php, ruby, game, gui, bathyscaphe, c, 計画中(planning stage), 翻訳, omegat, framework, twitter, dom, test, vb.net, directx, ゲームエンジン, btron, arduino, previewer

Emergent generative agents


Commit MetaInfo

Revisão: 35fbd054cec56752630a2152d93e3ac3e10e3830 (tree)
Hora: 2023-04-15 01:38:46
Autor: Corbin <cds@corb...>
Committer: Corbin

Mensagem de Log

Give sentence indices an inigo behavior.

The exponentially-decaying behavior appears to work well with the
given parsimony metric. I spitballed the exponent, based on a couple
thought files from some bots.

Temperature really should be lower, so I tweaked that too.

Sumário de Mudanças

Diff

--- a/agent.py
+++ b/agent.py
@@ -1,6 +1,7 @@
11 #!/usr/bin/env nix-shell
22 #! nix-shell -i python3 -p python3Packages.irc python3Packages.faiss python3Packages.transformers python3Packages.torch
33
4+from collections import defaultdict
45 from concurrent.futures import ThreadPoolExecutor
56 from datetime import datetime
67 from heapq import nsmallest
@@ -10,13 +11,10 @@ import random
1011 import re
1112 import sys
1213
13-from faiss import IndexFlatL2
14-import numpy as np
15-
1614 from irc.bot import SingleServerIRCBot
1715 from irc.strings import lower
1816
19-from common import Log, Timer, breakAt
17+from common import Log, SentenceIndex, breakAt
2018 from gens.trans import Flavor, HFGen
2119 from gens.camelid import CamelidGen
2220
@@ -46,29 +44,6 @@ def load_character(path):
4644 with open(os.path.join(path, "character.json"), "r") as handle:
4745 return json.load(handle)
4846
49-class SentenceIndex:
50- def __init__(self, path, dimensions):
51- self.path = path
52- self.index = IndexFlatL2(dimensions)
53-
54- def load(self):
55- with open(self.path, "r") as handle:
56- data = json.load(handle)
57- self.db = list(data.items())
58- self.index.add(np.array([row[1] for row in self.db], dtype="float32"))
59-
60- def save(self):
61- with open(self.path, "w") as f: json.dump(dict(self.db), f)
62-
63- def search(self, embedding, k):
64- with Timer("%d nearest neighbors" % k):
65- D, I = self.index.search(np.array([embedding], dtype="float32"), k)
66- return [self.db[i][0] for i in I[0] if i >= 0]
67-
68- def add(self, s, embedding):
69- self.index.add(np.array([embedding], dtype="float32"))
70- self.db.append((s, embedding))
71-
7247 logpath = sys.argv[2]
7348 character = load_character(logpath)
7449 startingChannels = character.pop("startingChannels")
@@ -88,7 +63,7 @@ max_context_length = gen.contextLength()
8863 thought_index = SentenceIndex(os.path.join(logpath, "thoughts.json"),
8964 llama_gen.embedding_width)
9065 thought_index.load()
91-print("~ Thought index:", thought_index.index.ntotal, "thoughts")
66+print("~ Thought index:", thought_index.size(), "thoughts")
9267
9368 executor = ThreadPoolExecutor(max_workers=1)
9469
@@ -108,6 +83,7 @@ class Agent(SingleServerIRCBot):
10883 self.startingChannels = startingChannels
10984 self.logpath = logpath
11085 self.logs = {}
86+ self.willReply = defaultdict(bool)
11187
11288 def on_join(self, c, e):
11389 channel = e.target
@@ -132,21 +108,19 @@ class Agent(SingleServerIRCBot):
132108 # https://github.com/jaraco/irc/blob/main/scripts/testbot.py
133109 nick = lower(self.connection.get_nickname())
134110 lowered = lower(line)
135- if self.thinking: print("~ Already thinking")
136- else:
137- self.thinkAbout(channel)
138- if (nick in lowered and random.random() <= 0.875):
139- self.generateReply(c, channel)
140- elif random.random() <= 0.125: self.generateReply(c, channel)
111+ if nick in lowered or random.random() <= 0.125:
112+ self.willReply[channel] = True
113+ if not self.thinking: self.thinkAbout(c, channel)
141114
142115 def thoughtPrompt(self):
143116 key = NO_THOUGHTS_EMBED if self.recent_thought is None else self.recent_thought[1]
144- # Fetch more thoughts than necessary, and always prefer shorter
117+ # Fetch more thoughts than necessary, and then always prefer shorter
145118 # thoughts. This is an attempt to prevent exponential rumination.
146- new_thoughts = thought_index.search(key, 10)
147- # .search() returns most relevant thoughts first; reversing the list
148- # creates more focused chains of thought.
149- new_thoughts = nsmallest(5, new_thoughts.reverse(), key=len)
119+ new_thoughts = thought_index.search(key, 20)
120+ # XXX .search() returns most relevant thoughts first; reversing the list
121+ # would create more focused chains of thought.
122+ # Smaller thoughts are better.
123+ new_thoughts = nsmallest(10, new_thoughts, key=len)
150124 if self.recent_thought is not None:
151125 new_thoughts.append(self.recent_thought[0])
152126 print("~ Thoughts:", *new_thoughts)
@@ -180,8 +154,10 @@ Users: {users}"""
180154 return "\n".join(lines)
181155
182156 def generateReply(self, c, channel):
157+ print("~ Will reply to channel:", channel)
158+ self.willReply.pop(channel, None)
183159 log = self.logs[channel]
184- nick = self.connection.get_nickname()
160+ nick = c.get_nickname()
185161 prefix = f"{datetime.now():%H:%M:%S} <{nick}>"
186162 examples = self.examplesFromOtherChannels(channel)
187163 # NB: "full" prompt needs log lines from current channel...
@@ -190,18 +166,20 @@ Users: {users}"""
190166 log.bumpCutoff(max_context_length, gen.countTokens, fullPrompt, prefix)
191167 # ...and current channel's log lines are added here.
192168 s = log.finishPrompt(fullPrompt, prefix)
193- print("~ log length:", len(log.l) - log.cutoff,
194- "prompt length (tokens):", gen.countTokens(s))
169+ # print("~ log length:", len(log.l) - log.cutoff,
170+ # "prompt length (tokens):", gen.countTokens(s))
195171 # NB: At this point, execution is kicked out to a thread.
196172 def cb(completion):
197173 self.thinking = False
198174 reply = breakIRCLine(completion.result())
199175 log.irc(datetime.now(), nick, reply)
200176 c.privmsg(channel, reply)
177+ if self.willReply:
178+ self.generateReply(c, next(iter(self.willReply)))
201179 self.thinking = True
202180 executor.submit(lambda: gen.complete(s)).add_done_callback(cb)
203181
204- def thinkAbout(self, channel):
182+ def thinkAbout(self, c, channel):
205183 print("~ Will ponder channel:", channel)
206184 s = prompt + self.newThoughtPrompt(channel)
207185 def cb(completion):
@@ -211,6 +189,8 @@ Users: {users}"""
211189 embedding = llama_gen.embed(thought)
212190 self.recent_thought = thought, embedding
213191 thought_index.add(thought, embedding)
192+ thought_index.prune()
193+ if self.willReply[channel]: self.generateReply(c, channel)
214194 self.thinking = True
215195 executor.submit(lambda: gen.complete(s)).add_done_callback(cb)
216196
--- a/append_thought.py
+++ b/append_thought.py
@@ -1,21 +1,24 @@
11 #!/usr/bin/env nix-shell
2-#! nix-shell -i python3 -p python3
2+#! nix-shell -i python3 -p python3Packages.faiss
33
4-import json, sys
4+import sys
55
6+from common import SentenceIndex
67 from gens.camelid import CamelidGen
78
89 path = sys.argv[-1]
910 gen = CamelidGen()
11+index = SentenceIndex(path, gen.embedding_width)
1012
11-with open(path, "r") as handle: db = json.load(handle)
12-print("Thought database:", len(db), "entries")
13+index.load()
14+print("Thought database:", index.size(), "entries")
1315
1416 while True:
1517 try: thought = input("> ").strip()
1618 except EOFError: break
1719 if not thought: break
18- db[thought] = gen.embed(thought)
20+ index.add(thought, gen.embed(thought))
21+ index.prune()
1922
20-print("Saving thought database:", len(db), "entries")
21-with open(path, "w") as handle: json.dump(db, handle)
23+print("Saving thought database:", index.size(), "entries")
24+index.save()
--- a/common.py
+++ b/common.py
@@ -1,6 +1,11 @@
11 from bisect import bisect
2+import json
3+import random
24 from time import perf_counter
35
6+from faiss import IndexFlatL2
7+import numpy as np
8+
49 class Timer:
510 "Basic context manager for timing an operation."
611 def __init__(self, label): self.l = label
@@ -59,3 +64,48 @@ def parseLine(line, speakers):
5964 for edge in speakers:
6065 if not line.startswith(edge): line = breakAt(line, edge)
6166 return line.strip()
67+
68+def parsimony(s): return 1 - 2 ** -(len(s) * (1 / 50))
69+
70+class SentenceIndex:
71+ def __init__(self, path, dimensions):
72+ self.path = path
73+ self.dimensions = dimensions
74+ self.index = None
75+
76+ def size(self): return self.index.ntotal
77+
78+ def rebuild(self):
79+ with Timer("rebuilding sentence index"):
80+ self.index = IndexFlatL2(self.dimensions)
81+ self.index.add(np.array([row[1] for row in self.db], dtype="float32"))
82+
83+ def load(self):
84+ with open(self.path, "r") as handle:
85+ data = json.load(handle)
86+ self.db = list(data.items())
87+ self.rebuild()
88+
89+ def save(self):
90+ with open(self.path, "w") as f: json.dump(dict(self.db), f)
91+
92+ def search(self, embedding, k):
93+ with Timer("%d nearest neighbors" % k):
94+ D, I = self.index.search(np.array([embedding], dtype="float32"), k)
95+ return [self.db[i][0] for i in I[0] if i >= 0]
96+
97+ def add(self, s, embedding):
98+ self.index.add(np.array([embedding], dtype="float32"))
99+ self.db.append((s, embedding))
100+
101+ def prune(self):
102+ # NB: This is the same maths as an inigo. Older thoughts are less
103+ # likely to be removed.
104+ i = int(random.expovariate(1 / 5))
105+ if not (0 < i < len(self.db)): return
106+ thought = self.db[-i][0]
107+ print("~ Is is short enough?", thought)
108+ if random.random() <= parsimony(thought):
109+ print("~ Would prune thought:", thought)
110+ # self.db.pop(-i)
111+ # self.rebuild()
--- a/gens/trans.py
+++ b/gens/trans.py
@@ -57,6 +57,6 @@ class HFGen:
5757 do_sample=True,
5858 # Force responses.
5959 min_length=5,
60- # Slightly sharpen results.
61- temperature=0.875, repetition_penalty=1.0625,
60+ # Sharpen results.
61+ temperature=0.75, repetition_penalty=1.125,
6262 )[0]["generated_text"]