mirror of
https://github.com/HKUDS/AI-Trader
synced 2026-04-21 13:37:41 +00:00
Merge pull request #190 from HKUDS/fix/issue-186-short-entry-price
Fix short position averaging on add
This commit is contained in:
commit
22c0674296
2 changed files with 97 additions and 2 deletions
|
|
@ -233,10 +233,14 @@ def _update_position_from_signal(
|
|||
if current_qty < 0:
|
||||
# Add to existing short
|
||||
new_qty = current_qty - quantity
|
||||
current_short_qty = abs(current_qty)
|
||||
new_entry_price = (
|
||||
(current_short_qty * row["entry_price"]) + (quantity * price)
|
||||
) / abs(new_qty)
|
||||
cursor.execute("""
|
||||
UPDATE positions SET quantity = ?, opened_at = ?
|
||||
UPDATE positions SET quantity = ?, entry_price = ?, opened_at = ?
|
||||
WHERE id = ?
|
||||
""", (new_qty, executed_at, position_id))
|
||||
""", (new_qty, new_entry_price, executed_at, position_id))
|
||||
print(f"[Position] {symbol}: increased short position to {new_qty}")
|
||||
else:
|
||||
# Create new short position (negative quantity for short)
|
||||
|
|
|
|||
91
service/server/tests/test_services.py
Normal file
91
service/server/tests/test_services.py
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
import sqlite3
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
SERVER_DIR = Path(__file__).resolve().parents[1]
|
||||
if str(SERVER_DIR) not in sys.path:
|
||||
sys.path.insert(0, str(SERVER_DIR))
|
||||
|
||||
from services import _update_position_from_signal
|
||||
|
||||
|
||||
class UpdatePositionFromSignalTests(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.conn = sqlite3.connect(":memory:")
|
||||
self.conn.row_factory = sqlite3.Row
|
||||
self.cursor = self.conn.cursor()
|
||||
self.cursor.execute(
|
||||
"""
|
||||
CREATE TABLE positions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
agent_id INTEGER NOT NULL,
|
||||
leader_id INTEGER,
|
||||
symbol TEXT NOT NULL,
|
||||
market TEXT NOT NULL DEFAULT 'us-stock',
|
||||
token_id TEXT,
|
||||
outcome TEXT,
|
||||
side TEXT NOT NULL,
|
||||
quantity REAL NOT NULL,
|
||||
entry_price REAL NOT NULL,
|
||||
current_price REAL,
|
||||
opened_at TEXT NOT NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.conn.close()
|
||||
|
||||
def test_short_add_updates_weighted_entry_price(self) -> None:
|
||||
self.cursor.execute(
|
||||
"""
|
||||
INSERT INTO positions (
|
||||
agent_id, leader_id, symbol, market, token_id, outcome,
|
||||
side, quantity, entry_price, opened_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
1,
|
||||
None,
|
||||
"BTC",
|
||||
"crypto",
|
||||
None,
|
||||
None,
|
||||
"short",
|
||||
-0.2,
|
||||
100.0,
|
||||
"2026-04-13T14:16:45Z",
|
||||
),
|
||||
)
|
||||
|
||||
_update_position_from_signal(
|
||||
agent_id=1,
|
||||
symbol="BTC",
|
||||
market="crypto",
|
||||
action="short",
|
||||
quantity=0.3,
|
||||
price=120.0,
|
||||
executed_at="2026-04-13T15:16:45Z",
|
||||
cursor=self.cursor,
|
||||
)
|
||||
|
||||
self.cursor.execute(
|
||||
"""
|
||||
SELECT quantity, entry_price, opened_at
|
||||
FROM positions
|
||||
WHERE agent_id = ? AND symbol = ? AND market = ?
|
||||
""",
|
||||
(1, "BTC", "crypto"),
|
||||
)
|
||||
row = self.cursor.fetchone()
|
||||
|
||||
self.assertIsNotNone(row)
|
||||
self.assertAlmostEqual(row["quantity"], -0.5)
|
||||
self.assertAlmostEqual(row["entry_price"], 112.0)
|
||||
self.assertEqual(row["opened_at"], "2026-04-13T15:16:45Z")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Reference in a new issue