• R/O
  • HTTP
  • SSH
  • HTTPS

Commit

Tags
No Tags

Frequently used words (click to add to your profile)

java, c++, android, linux, c#, windows, objective-c, cocoa, 誰得, qt, python, php, ruby, game, gui, bathyscaphe, c, 計画中(planning stage), 翻訳, omegat, framework, twitter, dom, test, vb.net, directx, ゲームエンジン, btron, arduino, previewer

Emergent generative agents


Commit MetaInfo

Revisãobb8a464e79a4ccd7e0ef517bbdd226ac6446c5d5 (tree)
Hora2023-06-12 03:14:32
AutorCorbin <cds@corb...>
CommiterCorbin

Mensagem de Log

Bump RWKV and LLaMA dependencies.

Mudança Sumário

Diff

--- a/flake.lock
+++ b/flake.lock
@@ -5,11 +5,11 @@
55 "systems": "systems"
66 },
77 "locked": {
8- "lastModified": 1681202837,
9- "narHash": "sha256-H+Rh19JDwRtpVPAWp64F+rlEtxUWBAQW28eAi3SRSzg=",
8+ "lastModified": 1685518550,
9+ "narHash": "sha256-o2d0KcvaXzTrPRIo0kOLV0/QXHhDQ5DTi+OxcjO8xqY=",
1010 "owner": "numtide",
1111 "repo": "flake-utils",
12- "rev": "cfacdce06f30d2b68473a46042957675eebb3401",
12+ "rev": "a1720a10a6cfe8234c0e93907ffe81be440f4cef",
1313 "type": "github"
1414 },
1515 "original": {
@@ -39,11 +39,11 @@
3939 "nixpkgs": "nixpkgs"
4040 },
4141 "locked": {
42- "lastModified": 1682476640,
43- "narHash": "sha256-mLVO3T86AaXg3CXJNEjxjAaUGRhB0+8J1yGRV9LN/8M=",
42+ "lastModified": 1685666615,
43+ "narHash": "sha256-+hL7AHJja9EyGbwKOIB4yK/6QsMkSgyi/Dv0YpWl9wA=",
4444 "owner": "MostAwesomeDude",
4545 "repo": "llama.cpp",
46- "rev": "8a2a3e5098187a70a3949aa8a9351f2f26478d84",
46+ "rev": "a48b63971f6e4229c00cf23da4cc09881a8f6874",
4747 "type": "github"
4848 },
4949 "original": {
@@ -70,11 +70,11 @@
7070 },
7171 "nixpkgs_2": {
7272 "locked": {
73- "lastModified": 1682453498,
74- "narHash": "sha256-WoWiAd7KZt5Eh6n+qojcivaVpnXKqBsVgpixpV2L9CE=",
73+ "lastModified": 1685564631,
74+ "narHash": "sha256-8ywr3AkblY4++3lIVxmrWZFzac7+f32ZEhH/A8pNscI=",
7575 "owner": "NixOS",
7676 "repo": "nixpkgs",
77- "rev": "c8018361fa1d1650ee8d4b96294783cf564e8a7f",
77+ "rev": "4f53efe34b3a8877ac923b9350c874e3dcd5dc0a",
7878 "type": "github"
7979 },
8080 "original": {
--- a/flake.nix
+++ b/flake.nix
@@ -13,13 +13,13 @@
1313 llama-lib = llama-cpp-lib.packages.${system}.default;
1414 llama-cpp-python = pkgs.python310.pkgs.buildPythonPackage rec {
1515 pname = "llama-cpp-python";
16- version = "0.1.38";
16+ version = "0.1.57";
1717
1818 src = pkgs.fetchFromGitHub {
1919 owner = "abetlen";
2020 repo = pname;
2121 rev = "v${version}";
22- sha256 = "sha256-/Ykndsp6puFxa+FSHNln9M2frS7/sMMBJSNJ/mU/CSI=";
22+ sha256 = "sha256-BrR3N+3KRu96j0MIydyrvFb2BN3COeBPISac+ixq3XM=";
2323 };
2424 format = "setuptools";
2525
@@ -28,8 +28,13 @@
2828 sed -i -e "s,_load_shared_library(_lib_base_name),ctypes.CDLL('${llama-lib}/lib/libllama.so')," llama_cpp/llama_cpp.py
2929 '';
3030
31+ # Imports server.app, which needs fancy Starlette packages I'm not
32+ # willing to deal with right now. ~ C.
33+ doCheck = false;
34+
3135 propagatedBuildInputs = with pkgs.python310.pkgs; [
32- typing-extensions
36+ numpy typing-extensions
37+ # anyio fastapi uvicorn
3338 ];
3439 };
3540 sentence-transformers = pkgs.python310.pkgs.buildPythonPackage rec {
@@ -57,12 +62,9 @@
5762
5863 src = pkgs.fetchFromGitHub {
5964 owner = "saharNooby";
60- # owner = "iacore";
6165 repo = "rwkv.cpp";
62- rev = "c736ef5411606b529d3a74c139ee111ef1a28bb9";
63- sha256 = "sha256-zJFmuhyY2kT/WVStBpHSnlmwclXZmVoiFvsurCDHW4E=";
64- # rev = "ae390c6";
65- # sha256 = "sha256-ojDsZgXwd3+E6AGtB/KANGz3Y0W5l9CWGjfhjJEefDQ=";
66+ rev = "363dfb1a061507aee661300fc8e2e153b6e99dc2";
67+ sha256 = "sha256-HlJmXMXSUNgPJN6TSGnNeeBeY3/9HmRH9Qa2d4jPEu4=";
6668 fetchSubmodules = true;
6769 };
6870
@@ -119,7 +121,7 @@
119121 devShells.default = pkgs.mkShell {
120122 name = "zirpu-env";
121123 packages = with pkgs; [
122- git
124+ git gdb
123125 # our Python
124126 py
125127 # catching Python mistakes
--- a/src/agent.py
+++ b/src/agent.py
@@ -17,7 +17,8 @@ from twisted.internet.threads import deferToThread
1717 from twisted.words.protocols.irc import IRCClient
1818
1919 from common import irc_line, Timer, SentenceIndex, breakAt
20-from gens.mawrkov import MawrkovGen, force
20+from gens.camelid import CamelidGen
21+from gens.mawrkov import MawrkovGen
2122 from gens.trans import SentenceEmbed
2223
2324 build_traits = " + ".join
@@ -29,9 +30,16 @@ def load_character(path):
2930 return json.load(handle)
3031
3132 MAX_NEW_TOKENS = 128
32-print("~ Initializing mawrkov adapter…")
33-model_path = sys.argv[1]
34-gen = MawrkovGen(model_path, MAX_NEW_TOKENS)
33+gens = {
34+ "llama": CamelidGen,
35+ "rwkv": MawrkovGen,
36+}
37+model_cls = sys.argv[1]
38+if model_cls not in gens:
39+ raise ValueError("must be one of %r" % tuple(gens.keys()))
40+print("~ Initializing adapter:", model_cls)
41+model_path = sys.argv[2]
42+gen = gens[model_cls](model_path, MAX_NEW_TOKENS)
3543 # Need to protect per-gen data structures in C.
3644 genLock = Lock()
3745 GiB = 1024 ** 3
@@ -50,7 +58,6 @@ prologues = {
5058
5159 class Mind:
5260 currentTag = None
53- logits = state = None
5461
5562 def __init__(self, yarn, name):
5663 self.yarn = yarn
@@ -103,7 +110,7 @@ class Mind:
103110 # Breaking the newline invariant...
104111 d.addCallback(lambda _: deferToThread(self.writeRaw, prefix))
105112 # ...so that this inference happens before the newline...
106- d.addCallback(lambda _: force(tokens, self.logits))
113+ d.addCallback(lambda _: self.yarn.force(tokens))
107114 # XXX decode should be on gen ABC
108115 d.addCallback(lambda t: gen.tokenizer.decode([t]))
109116
@@ -279,7 +286,7 @@ def go():
279286 clock = Clock()
280287 LoopingCall(clock.go).start(60 * 30, now=True)
281288
282- for logpath in sys.argv[2:]:
289+ for logpath in sys.argv[3:]:
283290 character = load_character(logpath)
284291 title = character["title"]
285292 firstStatement = f"I am {title}."
--- a/src/gens/camelid.py
+++ b/src/gens/camelid.py
@@ -8,41 +8,59 @@ class CamelidGen(Gen):
88 model_name = "LLaMA?"
99 model_arch = "LLaMA"
1010 def __init__(self, model_path, max_new_tokens):
11- self.llama = Llama(model_path, n_ctx=1024)
11+ self.llama = Llama(model_path)
1212 self.model_size = os.stat(model_path).st_size * 3 // 2
1313 self.max_new_tokens = max_new_tokens
1414
1515 def footprint(self): return self.model_size
1616 def contextLength(self): return llama_cpp.llama_n_ctx(self.llama.ctx)
1717
18- # XXX doesn't work?
1918 def tokenize(self, s): return self.llama.tokenize(s.encode("utf-8"))
20- def decode(self, ts): return self.llama.detokenize(ts)
19+ def decode(self, ts): return self.llama.detokenize(ts).decode("utf-8")
2120
2221 def fork(self):
2322 return CamelidYarn(self.max_new_tokens, self.llama, self.llama.save_state())
2423
24+yarn_cache = [None]
25+
2526 class CamelidYarn(Yarn):
2627 def __init__(self, max_new_tokens, llama, state):
2728 self.max_new_tokens = max_new_tokens
2829 self.llama = llama
2930 self.state = state
3031
32+ def activate(self):
33+ if yarn_cache[0] is not self:
34+ if yarn_cache[0]: yarn_cache[0].deactivate()
35+ yarn_cache[0] = self
36+ self.llama.load_state(self.state)
37+
38+ def deactivate(self): self.state = self.llama.save_state()
39+
3140 def feedForward(self, tokens):
32- self.llama.load_state(self.state)
41+ self.activate()
3342 self.llama.eval(tokens)
34- self.state = self.llama.save_state()
3543
3644 def complete(self):
37- self.llama.load_state(self.state)
45+ self.activate()
3846 tokens = []
3947 for _ in range(self.max_new_tokens):
4048 token = self.llama.sample()
41- if "\n" in self.llama.detokenize([token]): break
49+ if b"\n" in self.llama.detokenize([token]): break
4250 tokens.append(token)
43- self.state = self.llama.save_state()
4451 return tokens
4552
53+ def force(self, options):
54+ self.activate()
55+ return self.llama.sample(logits_processor=Force(options))
56+
57+class Force:
58+ def __init__(self, options): self.options = options
59+ def __call__(self, input_ids, scores):
60+ # +10 is hopefully not too much.
61+ for option in self.options: scores[option] += 10
62+ return scores
63+
4664 class CamelidEmbed:
4765 def __init__(self, model_path):
4866 self.llama = Llama(model_path, embedding=True)
--- a/src/gens/mawrkov.py
+++ b/src/gens/mawrkov.py
@@ -25,8 +25,10 @@ sampling = bare_import(RWKV_PATH, "sampling")
2525 TOKENIZER_PATH = os.path.join(RWKV, "share", "20B_tokenizer.json")
2626
2727 # Upstream recommends temp 0.7, top_p 0.5
28-TEMPERATURE = 0.8
29-TOP_P = 0.8
28+# TEMPERATURE = 0.8
29+# TOP_P = 0.8
30+TEMPERATURE = 0.9
31+TOP_P = 0.9
3032
3133 class MawrkovGen(Gen):
3234 model_name = "The Pile"
@@ -68,7 +70,7 @@ class MawrkovYarn(Yarn):
6870 self.feedForward([token])
6971 return tokens
7072
71-def force(options, logits):
72- # +10 is hopefully not too much.
73- biases = {opt: 10 for opt in options}
74- return sampling.sample_logits(logits, TEMPERATURE, TOP_P, biases)
73+ def force(self, options):
74+ # +10 is hopefully not too much.
75+ biases = {opt: 10 for opt in options}
76+ return sampling.sample_logits(self.logits, TEMPERATURE, TOP_P, biases)