buggy for 滑点 and 交易成本

This commit is contained in:
amstrongzyf 2025-11-26 14:28:24 +00:00
parent c3cefa31ae
commit 0639f1ee71
3 changed files with 350 additions and 49 deletions

View file

@ -20,6 +20,80 @@ from tools.price_tools import (get_latest_position, get_open_prices,
mcp = FastMCP("TradeTools")
# ============================================================================
# A股交易成本和滑点配置 (仅适用于 market == "cn")
# ============================================================================
CN_COMMISSION_RATE = 0.003 # 佣金费率: 0.3% (千分之三)
CN_COMMISSION_MIN = 5.0 # 最低佣金: 5元
CN_STAMP_DUTY_RATE = 0.001 # 印花税: 0.1% (千分之一, 仅卖出)
CN_TRANSFER_FEE_RATE = 0.00001 # 过户费: 0.001% (万分之一)
CN_SLIPPAGE_RATE = 0.002 # 固定滑点: 0.2%
def _apply_cn_slippage(base_price: float, action: str) -> float:
"""
应用A股固定滑点
Args:
base_price: 基准价格开盘价
action: 'buy' 'sell'
Returns:
应用滑点后的实际成交价
- 买入: base_price * (1 + 0.002) = 价格更高0.2%
- 卖出: base_price * (1 - 0.002) = 价格更低0.2%
"""
if action == 'buy':
return base_price * (1 + CN_SLIPPAGE_RATE)
else: # sell
return base_price * (1 - CN_SLIPPAGE_RATE)
def _calculate_cn_trading_costs(action: str, price: float, amount: int) -> dict:
"""
计算A股交易成本仅限中国A股市场
Args:
action: 交易类型 'buy' 'sell'
price: 交易价格已应用滑点后的价格
amount: 交易数量
Returns:
包含各项费用的字典:
{
'commission': 佣金,
'stamp_duty': 印花税仅卖出时有值,
'transfer_fee': 过户费,
'total_cost': 总成本
}
费率说明:
- 佣金: 0.3%, 最低5元
- 印花税: 0.1%, 仅卖出时收取
- 过户费: 0.001%
"""
trade_amount = price * amount
# 1. 佣金 (买卖双向收取最低5元)
commission = max(trade_amount * CN_COMMISSION_RATE, CN_COMMISSION_MIN)
# 2. 印花税 (仅卖出时收取)
stamp_duty = trade_amount * CN_STAMP_DUTY_RATE if action == 'sell' else 0.0
# 3. 过户费 (买卖双向收取)
transfer_fee = trade_amount * CN_TRANSFER_FEE_RATE
# 4. 总成本
total_cost = commission + stamp_duty + transfer_fee
return {
'commission': round(commission, 2),
'stamp_duty': round(stamp_duty, 2),
'transfer_fee': round(transfer_fee, 2),
'total_cost': round(total_cost, 2)
}
def _position_lock(signature: str):
"""Context manager for file-based lock to serialize position updates per signature."""
class _Lock:
@ -159,17 +233,32 @@ def buy(symbol: str, amount: int) -> Dict[str, Any]:
"market": market,
}
# Step 4: Validate buy conditions
# Calculate cash required for purchase: stock price × buy quantity
# Step 4: Apply slippage and calculate trading costs (CN market only)
# For A-shares, apply 0.2% slippage and calculate trading costs
actual_price = this_symbol_price # Default: use opening price as-is
trading_costs = None
if market == "cn":
# Apply fixed slippage: buy price is 0.2% higher
actual_price = _apply_cn_slippage(this_symbol_price, action='buy')
# Calculate trading costs for A-shares
trading_costs = _calculate_cn_trading_costs(action='buy', price=actual_price, amount=amount)
# Step 5: Validate buy conditions
# Calculate cash required: stock price × quantity + trading costs (if CN market)
trade_amount = actual_price * amount
total_cost = trade_amount + (trading_costs['total_cost'] if trading_costs else 0)
try:
cash_left = current_position["CASH"] - this_symbol_price * amount
cash_left = current_position["CASH"] - total_cost
except Exception as e:
# Defensive: if any unexpected structure, surface a clear error
return {
"error": f"Failed to compute cash after purchase: {e}",
"symbol": symbol,
"date": today_date,
"price": this_symbol_price,
"price": actual_price,
"amount": amount,
"position_keys": list(current_position.keys()),
}
@ -179,23 +268,25 @@ def buy(symbol: str, amount: int) -> Dict[str, Any]:
# Insufficient cash, return error message
return {
"error": "Insufficient cash! This action will not be allowed.",
"required_cash": this_symbol_price * amount,
"required_cash": total_cost,
"trade_amount": trade_amount,
"trading_costs": trading_costs['total_cost'] if trading_costs else 0,
"cash_available": current_position.get("CASH", 0),
"symbol": symbol,
"date": today_date,
}
else:
# Step 5: Execute buy operation, update position
# Step 6: Execute buy operation, update position
# Create a copy of current position to avoid directly modifying original data
new_position = current_position.copy()
# Decrease cash balance
# Decrease cash balance (including trading costs for CN market)
new_position["CASH"] = cash_left
# Increase stock position quantity
new_position[symbol] = new_position.get(symbol, 0) + amount
# Step 6: Record transaction to position.jsonl file
# Step 7: Record transaction to position.jsonl file
# Build file path: {project_root}/data/{log_path}/{signature}/position/position.jsonl
# Use append mode ("a") to write new transaction record
# Each operation ID increments by 1, ensuring uniqueness of operation sequence
@ -203,28 +294,39 @@ def buy(symbol: str, amount: int) -> Dict[str, Any]:
if log_path.startswith("./data/"):
log_path = log_path[7:] # Remove "./data/" prefix
position_file_path = os.path.join(project_root, "data", log_path, signature, "position", "position.jsonl")
# Prepare action record
action_record = {
"action": "buy",
"symbol": symbol,
"amount": amount,
"price": round(actual_price, 2),
}
# Add trading costs details for CN market
if trading_costs:
action_record["trading_costs"] = trading_costs
action_record["base_price"] = round(this_symbol_price, 2)
action_record["slippage"] = round(actual_price - this_symbol_price, 2)
with open(position_file_path, "a") as f:
# Write JSON format transaction record, containing date, operation ID, transaction details and updated position
print(
f"Writing to position.jsonl: {json.dumps({'date': today_date, 'id': current_action_id + 1, 'this_action':{'action':'buy','symbol':symbol,'amount':amount},'positions': new_position})}"
)
f.write(
json.dumps(
{
"date": today_date,
"id": current_action_id + 1,
"this_action": {"action": "buy", "symbol": symbol, "amount": amount},
"positions": new_position,
}
)
+ "\n"
)
# Step 7: Return updated position
record = {
"date": today_date,
"id": current_action_id + 1,
"this_action": action_record,
"positions": new_position,
}
print(f"Writing to position.jsonl: {json.dumps(record)}")
f.write(json.dumps(record) + "\n")
# Step 8: Return updated position
write_config_value("IF_TRADE", True)
print("IF_TRADE", get_config_value("IF_TRADE"))
return new_position
def _get_today_buy_amount(symbol: str, today_date: str, signature: str) -> int:
"""
Helper function to get the total amount bought today for T+1 restriction check
@ -391,18 +493,32 @@ def sell(symbol: str, amount: int) -> Dict[str, Any]:
"date": today_date,
}
# Step 5: Execute sell operation, update position
# Step 5: Apply slippage and calculate trading costs (CN market only)
# For A-shares, apply 0.2% slippage and calculate trading costs
actual_price = this_symbol_price # Default: use opening price as-is
trading_costs = None
if market == "cn":
# Apply fixed slippage: sell price is 0.2% lower
actual_price = _apply_cn_slippage(this_symbol_price, action='sell')
# Calculate trading costs for A-shares (includes stamp duty for selling)
trading_costs = _calculate_cn_trading_costs(action='sell', price=actual_price, amount=amount)
# Step 6: Execute sell operation, update position
# Create a copy of current position to avoid directly modifying original data
new_position = current_position.copy()
# Decrease stock position quantity
new_position[symbol] -= amount
# Increase cash balance: sell price × sell quantity
# Increase cash balance: sell price × sell quantity - trading costs (if CN market)
# Use get method to ensure CASH field exists, default to 0 if not present
new_position["CASH"] = new_position.get("CASH", 0) + this_symbol_price * amount
trade_amount = actual_price * amount
net_proceeds = trade_amount - (trading_costs['total_cost'] if trading_costs else 0)
new_position["CASH"] = new_position.get("CASH", 0) + net_proceeds
# Step 6: Record transaction to position.jsonl file
# Step 7: Record transaction to position.jsonl file
# Build file path: {project_root}/data/{log_path}/{signature}/position/position.jsonl
# Use append mode ("a") to write new transaction record
# Each operation ID increments by 1, ensuring uniqueness of operation sequence
@ -410,28 +526,39 @@ def sell(symbol: str, amount: int) -> Dict[str, Any]:
if log_path.startswith("./data/"):
log_path = log_path[7:] # Remove "./data/" prefix
position_file_path = os.path.join(project_root, "data", log_path, signature, "position", "position.jsonl")
# Prepare action record
action_record = {
"action": "sell",
"symbol": symbol,
"amount": amount,
"price": round(actual_price, 2),
}
# Add trading costs details for CN market
if trading_costs:
action_record["trading_costs"] = trading_costs
action_record["base_price"] = round(this_symbol_price, 2)
action_record["slippage"] = round(actual_price - this_symbol_price, 2)
action_record["net_proceeds"] = round(net_proceeds, 2)
with open(position_file_path, "a") as f:
# Write JSON format transaction record, containing date, operation ID and updated position
print(
f"Writing to position.jsonl: {json.dumps({'date': today_date, 'id': current_action_id + 1, 'this_action':{'action':'sell','symbol':symbol,'amount':amount},'positions': new_position})}"
)
f.write(
json.dumps(
{
"date": today_date,
"id": current_action_id + 1,
"this_action": {"action": "sell", "symbol": symbol, "amount": amount},
"positions": new_position,
}
)
+ "\n"
)
record = {
"date": today_date,
"id": current_action_id + 1,
"this_action": action_record,
"positions": new_position,
}
print(f"Writing to position.jsonl: {json.dumps(record)}")
f.write(json.dumps(record) + "\n")
# Step 7: Return updated position
# Step 8: Return updated position
write_config_value("IF_TRADE", True)
return new_position
if __name__ == "__main__":
# new_result = buy("AAPL", 1)
# print(new_result)

View file

@ -211,6 +211,69 @@ def calculate_asset_value(position, date, price_data, market='us'):
return total_value
def calculate_cn_legacy_cost_adjustment(position, price_data):
"""
Calculate trading cost adjustment for A-share legacy data (without trading_costs field).
Args:
position: Position entry from position.jsonl
price_data: Price data cache for getting historical prices
Returns:
float: Cost adjustment amount to subtract from CASH
"""
if 'this_action' not in position:
return 0.0
action_data = position['this_action']
symbol = action_data.get('symbol', '')
action = action_data.get('action', '')
# Only process A-share trades without trading_costs
is_astock = symbol.endswith('.SH') or symbol.endswith('.SZ')
has_costs = 'trading_costs' in action_data
if not (is_astock and not has_costs and action in ['buy', 'sell']):
return 0.0
amount = action_data.get('amount', 0)
if amount <= 0:
return 0.0
# Get price from price_data
date = position['date']
price = get_closing_price(symbol, date, price_data, 'cn')
if not price:
return 0.0
# Apply slippage (0.2% fixed)
if action == 'buy':
actual_price = price * 1.002 # +0.2%
else: # sell
actual_price = price * 0.998 # -0.2%
# Calculate trading costs
trade_amount = actual_price * amount
# Commission: 0.3%, minimum 5元
commission = max(trade_amount * 0.003, 5.0)
# Stamp duty: 0.1%, only for selling
stamp_duty = trade_amount * 0.001 if action == 'sell' else 0.0
# Transfer fee: 0.001%
transfer_fee = trade_amount * 0.00001
total_cost = commission + stamp_duty + transfer_fee
# Slippage cost
slippage_cost = abs(actual_price - price) * amount
# Total adjustment
return total_cost + slippage_cost
def process_agent_data_us(agent_config, market_config):
"""Process agent data for US market."""
agent_folder = agent_config['folder']
@ -326,8 +389,25 @@ def process_agent_data_cn(agent_config, market_config, price_cache):
# For hourly data, just return all positions without date filling
if preserve_hourly:
asset_history = []
cumulative_cost_adjustment = 0 # Track cost adjustments for legacy data
legacy_detected = False
for position in unique_positions:
asset_value = calculate_asset_value(position, position['dateKey'], price_cache, 'cn')
# Calculate cost adjustment for this position (if legacy A-share data)
cost_adjustment = calculate_cn_legacy_cost_adjustment(position, price_cache)
if cost_adjustment > 0 and not legacy_detected:
legacy_detected = True
print(f" 📊 Detected legacy A-share data, applying cost adjustments...")
cumulative_cost_adjustment += cost_adjustment
# Calculate asset value with adjusted CASH
adjusted_position = position.copy()
adjusted_positions_dict = adjusted_position['positions'].copy()
adjusted_positions_dict['CASH'] -= cumulative_cost_adjustment
adjusted_position['positions'] = adjusted_positions_dict
asset_value = calculate_asset_value(adjusted_position, position['dateKey'], price_cache, 'cn')
if asset_value is not None:
asset_history.append({
'date': position['dateKey'],
@ -342,12 +422,16 @@ def process_agent_data_cn(agent_config, market_config, price_cache):
result = {
'name': agent_folder,
'positions': [{'date': p['dateKey'], 'id': p['id'], 'positions': p['positions']} for p in unique_positions],
'positions': [{'date': p['dateKey'], 'id': p['id'], 'positions': p['positions'], 'this_action': p.get('this_action')} for p in unique_positions],
'assetHistory': asset_history,
'initialValue': asset_history[0]['value'] if asset_history else 10000,
'currentValue': asset_history[-1]['value'] if asset_history else 0,
'return': ((asset_history[-1]['value'] - asset_history[0]['value']) / asset_history[0]['value'] * 100) if asset_history else 0
'return': ((asset_history[-1]['value'] - asset_history[0]['value']) / asset_history[0]['value'] * 100) if asset_history else 0,
'currency': 'CNY'
}
if legacy_detected:
print(f" ✅ Applied total cost adjustment of {cumulative_cost_adjustment:.2f}")
print(f"{len(result['positions'])} positions, {len(asset_history)} data points (hourly)")
return result
@ -359,11 +443,26 @@ def process_agent_data_cn(agent_config, market_config, price_cache):
# Create position map for quick lookup
position_map = {pos['dateKey']: pos for pos in unique_positions}
# Calculate cumulative cost adjustments for legacy data
cumulative_costs_by_date = {}
cumulative_cost = 0
legacy_detected = False
for pos_key in sorted(position_map.keys()):
position = position_map[pos_key]
cost_adjustment = calculate_cn_legacy_cost_adjustment(position, price_cache)
if cost_adjustment > 0 and not legacy_detected:
legacy_detected = True
print(f" 📊 Detected legacy A-share data, applying cost adjustments...")
cumulative_cost += cost_adjustment
cumulative_costs_by_date[pos_key] = cumulative_cost
# Fill all dates in range (skip weekends)
asset_history = []
current_position = None
current_date = start_date
current_cumulative_cost = 0
while current_date <= end_date:
# Skip weekends
@ -373,10 +472,17 @@ def process_agent_data_cn(agent_config, market_config, price_cache):
# Use position for this date if exists, otherwise use last known position
if date_str in position_map:
current_position = position_map[date_str]
current_cumulative_cost = cumulative_costs_by_date.get(date_str, current_cumulative_cost)
if current_position:
# Calculate asset value
asset_value = calculate_asset_value(current_position, date_str, price_cache, 'cn')
# Calculate asset value with adjusted CASH
adjusted_position = current_position.copy()
adjusted_positions_dict = adjusted_position['positions'].copy()
adjusted_positions_dict['CASH'] -= current_cumulative_cost
adjusted_position['positions'] = adjusted_positions_dict
asset_value = calculate_asset_value(adjusted_position, date_str, price_cache, 'cn')
if asset_value is not None:
asset_history.append({
@ -400,7 +506,8 @@ def process_agent_data_cn(agent_config, market_config, price_cache):
'assetHistory': asset_history,
'initialValue': asset_history[0]['value'] if asset_history else 10000,
'currentValue': asset_history[-1]['value'] if asset_history else 0,
'return': ((asset_history[-1]['value'] - asset_history[0]['value']) / asset_history[0]['value'] * 100) if asset_history else 0
'return': ((asset_history[-1]['value'] - asset_history[0]['value']) / asset_history[0]['value'] * 100) if asset_history else 0,
'currency': 'CNY'
}
print(f"{len(positions)} positions, {len(asset_history)} data points")

View file

@ -146,18 +146,82 @@ def load_all_price_files(data_dir, is_crypto=False, is_astock=False):
def calculate_portfolio_values(positions, price_data, is_crypto=False, verbose=True):
"""
Calculate portfolio value at each timestamp.
For A-share legacy data (without trading_costs field), this function will
retroactively calculate and adjust for trading costs and slippage.
Returns:
DataFrame with columns: date, cash, stock_value, total_value
"""
portfolio_values = []
missing_prices = set()
cumulative_cost_adjustment = 0 # Track cumulative cost adjustments for legacy data
legacy_data_detected = False
for entry in positions:
date = entry['date']
pos = entry['positions']
pos = entry['positions'].copy() # Make a copy to avoid modifying original
# Check if this is A-share legacy data (no trading_costs field)
if 'this_action' in entry:
action_data = entry['this_action']
symbol = action_data.get('symbol', '')
action = action_data.get('action', '')
# Detect A-share symbols
is_astock = symbol.endswith('.SH') or symbol.endswith('.SZ')
has_costs = 'trading_costs' in action_data
# If A-share AND no trading costs AND is a trade action
if is_astock and not has_costs and action in ['buy', 'sell']:
if not legacy_data_detected:
legacy_data_detected = True
if verbose:
print(f"📊 Detected legacy A-share data (without trading costs). Applying retroactive cost adjustments...")
amount = action_data.get('amount', 0)
if amount > 0:
# Get price from price_data
price = get_price_at_date(price_data, symbol, date, is_crypto)
if price:
# Apply slippage (0.2% fixed)
if action == 'buy':
actual_price = price * 1.002 # +0.2%
else: # sell
actual_price = price * 0.998 # -0.2%
# Calculate trading costs
trade_amount = actual_price * amount
# Commission: 0.3%, minimum 5元
commission = max(trade_amount * 0.003, 5.0)
# Stamp duty: 0.1%, only for selling
stamp_duty = trade_amount * 0.001 if action == 'sell' else 0.0
# Transfer fee: 0.001%
transfer_fee = trade_amount * 0.00001
total_cost = commission + stamp_duty + transfer_fee
# Slippage cost
slippage_cost = abs(actual_price - price) * amount
# Total adjustment: costs + slippage
# For both buy and sell, these costs reduce cash
adjustment = total_cost + slippage_cost
cumulative_cost_adjustment += adjustment
if verbose and cumulative_cost_adjustment > 0:
print(f" {date} {action} {symbol}: cost adjustment = {adjustment:.2f} (cumulative: {cumulative_cost_adjustment:.2f})")
# Adjust CASH for legacy data
cash = pos.get('CASH', 0)
if cumulative_cost_adjustment > 0:
cash -= cumulative_cost_adjustment
stock_value = 0
# Calculate value of all stock holdings
@ -185,6 +249,9 @@ def calculate_portfolio_values(positions, price_data, is_crypto=False, verbose=T
df = pd.DataFrame(portfolio_values)
df['date'] = pd.to_datetime(df['date'])
if legacy_data_detected and verbose:
print(f"✅ Applied total cost adjustment of {cumulative_cost_adjustment:.2f} for legacy A-share data")
if not verbose and missing_prices:
print(f"Warning: {len(missing_prices)} missing price entries (use --verbose to see details)")