策略:纯信号模式(不定期换仓,止损止盈卖出后因子选股补仓) 股票池:沪深300+中证500(800+只) 止损-8% / 止盈+25% / 单只20% / 最多5只 5年回测:+371.7%,夏普0.82,年化21.3% 组件: - engine.py: 核心交易引擎 + 因子评分 + 数据管理 - scheduler.py: APScheduler定时调度 + HTTP状态接口 - trade_tool.py: 命令行工具 - config.json: 策略参数配置 - Dockerfile + docker-compose.yml: 容器化部署 日志系统: - 文件日志(按日轮转,保留90天) - SQLite: trades/daily_log/signal_log/system_log
895 lines
34 KiB
Python
895 lines
34 KiB
Python
#!/usr/bin/env python3
"""
Pure-signal trading engine v2.0
===============================

Fixed-parameter pure-signal strategy (defaults; overridable via config.json):
- No scheduled rebalancing; positions are sold only on stop-loss / take-profit
- After a sale, the freed slot is refilled right away via factor stock selection
- Every 10 days, scan for additional candidates
- Fixed stop-loss -8% / take-profit +25%
- 20% of capital per position / at most 5 positions
"""

# NOTE(review): math, traceback, defaultdict, ThreadPoolExecutor and
# as_completed appear unused in this module; kept in case other modules rely
# on them being importable from here.
import os, sys, json, time, math, sqlite3, logging, traceback
from datetime import datetime, timedelta
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
import tushare as ts

# ─── Paths ───
# Runtime directories live under the parent of this file's directory (APP_DIR).
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
APP_DIR = os.path.dirname(BASE_DIR)
DATA_DIR = os.path.join(APP_DIR, 'data')      # state.json, trading.db, pool_codes.json
LOG_DIR = os.path.join(APP_DIR, 'logs')       # log files written by setup_logger
CONFIG_DIR = os.path.join(APP_DIR, 'config')  # config.json (only read; not created here)

# Import-time side effect: ensure the writable directories exist.
for d in [DATA_DIR, LOG_DIR]:
    os.makedirs(d, exist_ok=True)
# ─── Logging ───
def setup_logger(name, log_file=None, level=logging.INFO):
    """Return the named logger with console (stdout) and optional file output.

    Args:
        name: logger name passed to ``logging.getLogger``.
        log_file: optional file name inside ``LOG_DIR``. When given, a
            ``TimedRotatingFileHandler`` rotates it at midnight, keeps 90
            backups and writes UTF-8.
        level: level applied to the logger on every call.

    Returns:
        logging.Logger: the configured logger.

    The function is idempotent per destination: calling it again for the same
    logger does not attach a second console handler or a second handler for the
    same file, so records are never emitted twice.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)
    fmt = logging.Formatter(
        '%(asctime)s [%(levelname)s] %(name)s: %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # Console. FileHandler subclasses StreamHandler, so match the exact type
    # to avoid mistaking a file handler for the console handler.
    if not any(type(h) is logging.StreamHandler for h in logger.handlers):
        ch = logging.StreamHandler(sys.stdout)
        ch.setFormatter(fmt)
        logger.addHandler(ch)

    # File
    if log_file:
        from logging.handlers import TimedRotatingFileHandler
        path = os.path.abspath(os.path.join(LOG_DIR, log_file))
        # FileHandler stores the absolute path in ``baseFilename``.
        if not any(getattr(h, 'baseFilename', None) == path for h in logger.handlers):
            fh = TimedRotatingFileHandler(
                path, when='midnight', backupCount=90, encoding='utf-8'
            )
            fh.setFormatter(fmt)
            logger.addHandler(fh)

    return logger
# Module-wide logger used throughout this file: stdout plus LOG_DIR/engine.log,
# rotated at midnight with 90 backups kept.
log = setup_logger('engine', 'engine.log')
# ─── Config loading ───
def load_config():
    """Load the strategy configuration.

    Reads ``CONFIG_DIR/config.json`` when it exists; otherwise returns the
    built-in defaults below. A non-empty ``TUSHARE_TOKEN`` environment
    variable takes precedence over the token stored in the file.

    Returns:
        dict: configuration with (at least, for the defaults) the keys
        ``tushare_token``, ``initial_cash``, ``position``, ``exit``, ``fee``
        and ``scan``.
    """
    cfg_path = os.path.join(CONFIG_DIR, 'config.json')
    token = os.environ.get('TUSHARE_TOKEN', '')

    if os.path.exists(cfg_path):
        # Explicit UTF-8: the config may contain non-ASCII text and the
        # platform default encoding is not guaranteed to be UTF-8.
        with open(cfg_path, encoding='utf-8') as f:
            cfg = json.load(f)
        if token:
            cfg['tushare_token'] = token
        return cfg

    return {
        'tushare_token': token,
        'initial_cash': 100000,
        'position': {
            'max_position_pct': 0.20, 'max_holdings': 5,
            'top_n_buy': 3, 'min_buy_amount': 5000
        },
        'exit': {'stop_loss_pct': -0.08, 'take_profit_pct': 0.25},
        'fee': {'commission_rate': 0.0003, 'min_commission': 5, 'stamp_tax_rate': 0.001},
        'scan': {'interval_days': 10, 'factor_pool_size': 50},
    }
# =====================================================================
# Factor engine
# =====================================================================
class FactorEngine:
    """Pure price/volume factor scoring.

    A stock's composite score is the sum, over FACTOR_DEFS, of
    ``raw_value * weight * direction`` for every factor that could be computed
    and is finite. ``direction`` +1 rewards larger raw values, -1 smaller ones.
    """

    FACTOR_DEFS = {
        'hist_vol_20d': {'weight': 0.20, 'direction': 1},
        'atr_ratio': {'weight': 0.18, 'direction': 1},
        'volume_ratio': {'weight': 0.10, 'direction': -1},
        'vol_breakout': {'weight': 0.08, 'direction': -1},
        'macd_hist_trend': {'weight': 0.06, 'direction': -1},
        'mom_60d': {'weight': 0.04, 'direction': -1},
        'ma20_slope': {'weight': 0.04, 'direction': 1},
        'rsi_14': {'weight': 0.10, 'direction': -1},
        'ma_dist': {'weight': 0.10, 'direction': -1},
        'turnover_avg': {'weight': 0.10, 'direction': 1},
    }

    def score_all(self, stock_data, date, date_to_idx, pool_size=50):
        """Score every stock on ``date`` and return the top ``pool_size``.

        Args:
            stock_data: ``{code: {'closes', 'highs', 'lows', 'volumes'[,
                'amounts']: per-bar lists in chronological order}}``.
            date: scoring date; looked up in ``date_to_idx[code]``.
            date_to_idx: ``{code: {date: bar index into that stock's lists}}``.
            pool_size: maximum number of entries returned.

        Returns:
            list[tuple[str, float]]: ``(code, score)`` sorted by score,
            highest first.

        Only bars up to and including ``date`` are used. Stocks whose bar index
        on ``date`` is below 60, that yield fewer than 5 factors, or whose
        computation raises are skipped.

        NOTE(review): factors are combined by raw value and live on very
        different scales (rsi_14 is 0..100 while return-like factors are around
        0.01), so large-scale factors dominate regardless of their weights.
        Cross-sectional normalisation (rank / z-score) would make the weights
        meaningful; not changed here to preserve the current ranking.
        """
        raw_scores = {}
        for code, sd in stock_data.items():
            eidx = date_to_idx.get(code, {}).get(date)
            if eidx is None or eidx < 60:
                continue
            try:
                if len(sd['closes'][:eidx + 1]) < 60:
                    continue
                factors = self._compute_factors(sd, eidx)
                if len(factors) >= 5:
                    raw_scores[code] = self._composite_score(factors)
            except Exception as exc:
                # Best-effort ranking: one malformed series must not abort it.
                logging.getLogger('engine').debug(f"因子计算失败 {code}: {exc}")
                continue

        ranked = sorted(raw_scores.items(), key=lambda x: -x[1])
        return ranked[:pool_size]

    def _compute_factors(self, sd, eidx):
        """Return ``{factor_name: raw_value}`` computed from bars ``0..eidx``."""
        end = eidx + 1
        c = np.array(sd['closes'][:end], dtype=float)
        highs = np.array(sd['highs'][:end], dtype=float)
        lows = np.array(sd['lows'][:end], dtype=float)
        volumes = np.array(sd['volumes'][:end], dtype=float)
        n = len(c)
        factors = {}

        # hist_vol_20d: std-dev of the last 20 simple daily returns
        if n >= 21:
            rets = np.diff(c[-21:]) / c[-21:-1]
            factors['hist_vol_20d'] = np.std(rets)

        # atr_ratio: 14-bar average true range / last close
        if n >= 15:
            prev_close = c[-15:-1]  # close preceding each of the last 14 bars
            tr = np.maximum(highs[-14:] - lows[-14:],
                            np.maximum(np.abs(highs[-14:] - prev_close),
                                       np.abs(lows[-14:] - prev_close)))
            if c[-1] > 0:
                factors['atr_ratio'] = np.mean(tr) / c[-1]

        # volume_ratio: 5-bar mean volume / 20-bar mean volume
        if n >= 21:
            vol20 = np.mean(volumes[-20:])
            if vol20 > 0:
                factors['volume_ratio'] = np.mean(volumes[-5:]) / vol20

        # vol_breakout: last volume / 20-bar mean volume
        if n >= 21:
            ma_vol = np.mean(volumes[-20:])
            if ma_vol > 0:
                factors['vol_breakout'] = volumes[-1] / ma_vol

        # macd_hist_trend: MACD DIF (EMA12 - EMA26) over the last 35 closes,
        # normalised by the last close. Both EMAs are seeded with the first
        # close of the window (they used to be seeded with the latest close,
        # which pulled both averages toward today's price).
        if n >= 35:
            start = n - 35
            ema12 = ema26 = c[start]
            for price in c[start + 1:]:
                ema12 = price * (2 / 13) + ema12 * (11 / 13)
                ema26 = price * (2 / 27) + ema26 * (25 / 27)
            dif = ema12 - ema26
            factors['macd_hist_trend'] = dif / c[-1] if c[-1] > 0 else 0

        # mom_60d: 60-bar return
        if n >= 61:
            factors['mom_60d'] = c[-1] / c[-61] - 1

        # ma20_slope: change of the 20-bar MA versus 5 bars ago
        if n >= 25:
            ma20_now = np.mean(c[-20:])
            ma20_5 = np.mean(c[-25:-5])
            if ma20_5 > 0:
                factors['ma20_slope'] = ma20_now / ma20_5 - 1

        # rsi_14: simple-average RSI over the last 14 price changes
        if n >= 16:
            deltas = np.diff(c[-15:])
            avg_gain = np.mean(np.where(deltas > 0, deltas, 0))
            avg_loss = np.mean(np.where(deltas < 0, -deltas, 0))
            if avg_loss > 0:
                factors['rsi_14'] = 100 - 100 / (1 + avg_gain / avg_loss)
            else:
                factors['rsi_14'] = 100

        # ma_dist: distance of the last close from the 60-bar MA
        if n >= 61:
            ma60 = np.mean(c[-60:])
            if ma60 > 0:
                factors['ma_dist'] = c[-1] / ma60 - 1

        # turnover_avg: mean of amount / (close * 1e7) over the last 10 bars.
        # Each amount is paired with the close of the SAME bar (previously the
        # last 10 amounts were paired with the first 10 closes of the history).
        amounts = sd.get('amounts', [])
        if n >= 11 and len(amounts) >= n:
            recent_amounts = amounts[:end][-10:]
            recent_closes = c[-10:]
            factors['turnover_avg'] = np.mean([
                a / (p * 1e7) if p > 0 else 0
                for a, p in zip(recent_amounts, recent_closes)
            ])

        return factors

    def _composite_score(self, factors):
        """Weighted, direction-signed sum of the finite factor values."""
        score = 0
        for fname, fdef in self.FACTOR_DEFS.items():
            val = factors.get(fname)
            if val is None or np.isnan(val) or np.isinf(val):
                continue
            score += val * fdef['weight'] * fdef['direction']
        return score
# =====================================================================
# Data management
# =====================================================================
class DataManager:
    """Fetch Tushare market data with light in-memory and on-disk caching."""

    # (output key, Tushare daily column) for one daily bar.
    _BAR_FIELDS = (
        ('open', 'open'), ('high', 'high'), ('low', 'low'),
        ('close', 'close'), ('volume', 'vol'), ('amount', 'amount'),
    )

    def __init__(self, token):
        self.pro = ts.pro_api(token)
        self._stock_basic = None   # memoised stock_basic DataFrame
        self._index_daily = {}     # NOTE(review): not read anywhere in this module
        self._pool_codes = None    # memoised stock universe (list of ts_code)

    def get_stock_basic(self):
        """Return (and memoise) ``stock_basic`` for list_status='L'."""
        if self._stock_basic is None:
            log.info("拉取股票基础信息...")
            self._stock_basic = self.pro.stock_basic(
                exchange='', list_status='L',
                fields='ts_code,symbol,name,industry,list_date'
            )
        return self._stock_basic

    def _get_index_codes(self, index_code, label):
        """Return the constituents of ``index_code`` at its latest snapshot.

        ``index_weight`` returns one row per constituent per snapshot date.
        A query spanning several snapshots would otherwise yield the union of
        all of them, including stocks that have since left the index, so only
        rows with the most recent ``trade_date`` are kept.

        Raises:
            RuntimeError: if Tushare returns no rows at all.
        """
        log.info(f"拉取{label}成分股...")
        # Recent window: a couple of months always contains at least one snapshot.
        start = (datetime.now() - timedelta(days=62)).strftime('%Y%m%d')
        df = self.pro.index_weight(index_code=index_code, start_date=start)
        if df is None or df.empty:
            df = self.pro.index_weight(index_code=index_code)
        if df is None or df.empty:
            raise RuntimeError(f"index_weight returned no data for {index_code}")
        if 'trade_date' in df.columns:
            df = df[df['trade_date'] == df['trade_date'].max()]
        codes = df['con_code'].unique().tolist()
        log.info(f" {label}: {len(codes)}只")
        return codes

    def get_hs300_codes(self):
        """Return the latest CSI 300 constituent codes."""
        return self._get_index_codes('399300.SZ', '沪深300')

    def get_zz500_codes(self):
        """Return the latest CSI 500 constituent codes."""
        return self._get_index_codes('000905.SH', '中证500')

    @staticmethod
    def _read_pool_cache(cache_file, max_age=7 * 86400):
        """Return cached pool codes if the cache is fresh and readable, else None."""
        if not os.path.exists(cache_file):
            return None
        if time.time() - os.path.getmtime(cache_file) >= max_age:
            return None
        try:
            with open(cache_file, encoding='utf-8') as f:
                codes = json.load(f)
        except (OSError, ValueError) as e:
            log.warning(f"股票池缓存读取失败,重新拉取: {e}")
            return None
        # An empty cache is treated as a miss rather than an empty universe.
        return codes or None

    def get_pool_codes(self):
        """Return the stock universe: CSI 300 ∪ CSI 500 minus names matching 'ST|退'.

        Memoised on the instance and cached in DATA_DIR/pool_codes.json for
        7 days. The list is sorted so the cache content is deterministic.
        """
        if self._pool_codes is not None:
            return self._pool_codes

        cache_file = os.path.join(DATA_DIR, 'pool_codes.json')
        cached = self._read_pool_cache(cache_file)
        if cached is not None:
            self._pool_codes = cached
            log.info(f"股票池缓存命中: {len(cached)}只")
            return cached

        codes = set(self.get_hs300_codes()) | set(self.get_zz500_codes())

        # Drop ST / delisting-tagged names (set lookup instead of list scan).
        basic = self.get_stock_basic()
        flagged = set(basic.loc[basic['name'].str.contains('ST|退', na=False), 'ts_code'])
        all_codes = sorted(codes - flagged)

        self._pool_codes = all_codes
        with open(cache_file, 'w', encoding='utf-8') as f:
            json.dump(all_codes, f)
        log.info(f"股票池: {len(all_codes)}只(去重后)")
        return all_codes

    def get_trade_dates(self, start_date=None, end_date=None):
        """Return sorted SSE open dates (YYYYMMDD) in [start_date, end_date].

        Defaults: end = today, start = 60 calendar days ago. Returns [] when
        the calendar query yields nothing.
        """
        if end_date is None:
            end_date = datetime.now().strftime('%Y%m%d')
        if start_date is None:
            start_date = (datetime.now() - timedelta(days=60)).strftime('%Y%m%d')
        df = self.pro.trade_cal(exchange='SSE', start_date=start_date, end_date=end_date, is_open=1)
        if df is None or df.empty:
            return []
        return sorted(df['cal_date'].tolist())

    def get_daily_batch(self, trade_date):
        """Return ``{ts_code: {open, high, low, close, volume, amount}}`` for one day.

        Returns {} when there is no data or the request fails (logged).
        """
        try:
            df = self.pro.daily(trade_date=trade_date)
            if df is None or df.empty:
                return {}
            # to_dict('records') is far cheaper than iterrows for ~5000 rows.
            return {
                rec['ts_code']: {key: float(rec.get(src, 0)) for key, src in self._BAR_FIELDS}
                for rec in df.to_dict('records')
            }
        except Exception as e:
            log.warning(f"拉取{trade_date}日线失败: {e}")
            return {}

    def get_stock_daily(self, ts_code, start_date, end_date):
        """Return one stock's daily bars as parallel lists (oldest first), or None."""
        try:
            df = self.pro.daily(ts_code=ts_code, start_date=start_date, end_date=end_date)
            if df is None or df.empty:
                return None
            df = df.sort_values('trade_date')
            return {
                'dates': df['trade_date'].tolist(),
                'opens': df['open'].astype(float).tolist(),
                'highs': df['high'].astype(float).tolist(),
                'lows': df['low'].astype(float).tolist(),
                'closes': df['close'].astype(float).tolist(),
                'volumes': df['vol'].astype(float).tolist(),
                'amounts': df['amount'].astype(float).tolist(),
            }
        except Exception as e:
            log.debug(f"拉取{ts_code}失败: {e}")
            return None

    def get_index_daily(self, ts_code='000300.SH', days=120):
        """Return ``(closes, volumes)`` of an index, oldest first; ([], []) on failure.

        The query window is ``days * 1.5`` calendar days ending today.
        """
        end = datetime.now().strftime('%Y%m%d')
        start = (datetime.now() - timedelta(days=int(days * 1.5))).strftime('%Y%m%d')
        try:
            df = self.pro.index_daily(ts_code=ts_code, start_date=start, end_date=end)
            if df is None or df.empty:
                return [], []
            df = df.sort_values('trade_date')
            return df['close'].astype(float).tolist(), df['vol'].astype(float).tolist()
        except Exception as e:
            log.warning(f"拉取指数{ts_code}失败: {e}")
            return [], []

    def load_all_history(self, days=300):
        """Load daily bars for every stock in the pool.

        Fetches one full-market batch per trading day over a window of
        ``days * 1.5`` calendar days ending today.

        Returns:
            tuple: ``(stock_data, trade_dates)``. ``stock_data`` maps code to
            parallel per-bar lists ('dates', 'opens', 'highs', 'lows',
            'closes', 'volumes', 'amounts') plus 'name'; ``trade_dates`` is
            the sorted SSE calendar of the window. ``({}, [])`` when the
            calendar cannot be fetched.
        """
        # Set membership: this is tested for every row of every day.
        pool = set(self.get_pool_codes())
        end = datetime.now().strftime('%Y%m%d')
        start = (datetime.now() - timedelta(days=int(days * 1.5))).strftime('%Y%m%d')

        trade_dates = self.get_trade_dates(start, end)
        if not trade_dates:
            log.error("无法获取交易日历")
            return {}, []

        log.info(f"交易日历: {len(trade_dates)}天 ({trade_dates[0]}~{trade_dates[-1]})")

        # (series key in stock_data, key in a get_daily_batch row)
        series = (('opens', 'open'), ('highs', 'high'), ('lows', 'low'),
                  ('closes', 'close'), ('volumes', 'volume'), ('amounts', 'amount'))

        # Fetch day by day
        stock_data = {}
        for i, td in enumerate(trade_dates):
            daily = self.get_daily_batch(td)
            if not daily:
                continue
            for code, row in daily.items():
                if code not in pool:
                    continue
                sd = stock_data.get(code)
                if sd is None:
                    sd = stock_data[code] = {
                        'dates': [], 'opens': [], 'highs': [], 'lows': [],
                        'closes': [], 'volumes': [], 'amounts': [], 'name': code
                    }
                sd['dates'].append(td)
                for key, src in series:
                    sd[key].append(row[src])

            if (i + 1) % 50 == 0:
                log.info(f" [{i+1}/{len(trade_dates)}] {len(stock_data)}只")

        log.info(f"数据加载完成: {len(stock_data)}只 × {len(trade_dates)}天")

        # Fill in display names; best-effort (codes remain as names on failure).
        try:
            basic = self.get_stock_basic()
            name_map = dict(zip(basic['ts_code'], basic['name']))
            for code, sd in stock_data.items():
                if code in name_map:
                    sd['name'] = name_map[code]
        except Exception as e:
            log.warning(f"填充股票名称失败: {e}")

        return stock_data, trade_dates
# =====================================================================
# Trading engine
# =====================================================================
class TradingEngine:
    """Pure-signal trading engine.

    Positions are closed only by stop-loss / take-profit; free slots are
    refilled from the FactorEngine ranking. Portfolio state is persisted in
    DATA_DIR/state.json; trades, signals, daily NAV and system events are
    appended to SQLite (DATA_DIR/trading.db).
    """

    def __init__(self, config=None):
        self.cfg = config or load_config()
        self.token = self.cfg.get('tushare_token', '')
        self.dm = DataManager(self.token)
        self.factor = FactorEngine()

        # Strategy parameters (fallbacks match load_config()'s built-in defaults)
        pos_cfg = self.cfg.get('position', {})
        exit_cfg = self.cfg.get('exit', {})
        fee_cfg = self.cfg.get('fee', {})
        scan_cfg = self.cfg.get('scan', {})

        self.initial_cash = float(self.cfg.get('initial_cash', 100000))
        self.max_position_pct = pos_cfg.get('max_position_pct', 0.20)
        self.max_holdings = pos_cfg.get('max_holdings', 5)
        self.top_n_buy = pos_cfg.get('top_n_buy', 3)
        # Smallest order amount scan_buy will propose.
        self.min_buy_amount = pos_cfg.get('min_buy_amount', 5000)
        self.stop_loss_pct = exit_cfg.get('stop_loss_pct', -0.08)
        self.take_profit_pct = exit_cfg.get('take_profit_pct', 0.25)
        self.commission_rate = fee_cfg.get('commission_rate', 0.0003)
        self.min_commission = fee_cfg.get('min_commission', 5)
        self.stamp_tax_rate = fee_cfg.get('stamp_tax_rate', 0.001)
        self.scan_interval = scan_cfg.get('interval_days', 10)
        self.pool_size = scan_cfg.get('factor_pool_size', 50)

        # State (_load_state reads scan_interval / initial_cash set above)
        self.state = self._load_state()
        self._init_db()

    # ─── Persistence ───

    def _state_file(self):
        return os.path.join(DATA_DIR, 'state.json')

    def _db_file(self):
        return os.path.join(DATA_DIR, 'trading.db')

    def _load_state(self):
        """Load state.json, or return a fresh state when it does not exist."""
        path = self._state_file()
        if os.path.exists(path):
            with open(path, encoding='utf-8') as f:
                return json.load(f)
        return {
            'cash': self.initial_cash,
            'positions': {},
            'last_scan_date': None,
            'last_scan_idx': -self.scan_interval,
            'trade_count': 0,
            'created': datetime.now().isoformat(),
            'nav_history': [],
        }

    def save_state(self):
        """Persist state to state.json atomically.

        The JSON is written to a temp file and moved over the target with
        ``os.replace``, so a crash mid-write cannot leave a truncated
        state.json (which would lose the positions).
        """
        self.state['updated'] = datetime.now().isoformat()
        path = self._state_file()
        tmp_path = path + '.tmp'
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(self.state, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, path)

    def _db_execute(self, sql, params=()):
        """Execute one statement on trading.db and commit; always closes the connection."""
        db = sqlite3.connect(self._db_file())
        try:
            db.execute(sql, params)
            db.commit()
        finally:
            db.close()

    def _init_db(self):
        """Create the trades / daily_log / signal_log / system_log tables if missing."""
        ddl = (
            '''CREATE TABLE IF NOT EXISTS trades (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                date TEXT NOT NULL,
                code TEXT NOT NULL,
                name TEXT,
                direction TEXT NOT NULL,
                qty INTEGER,
                price REAL,
                amount REAL,
                commission REAL,
                stamp_tax REAL,
                net_amount REAL,
                pnl REAL,
                pnl_pct REAL,
                reason TEXT,
                factor_score REAL,
                created_at TEXT DEFAULT CURRENT_TIMESTAMP)''',
            '''CREATE TABLE IF NOT EXISTS daily_log (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                date TEXT NOT NULL UNIQUE,
                cash REAL,
                market_value REAL,
                total_nav REAL,
                holdings_count INTEGER,
                return_pct REAL,
                details TEXT,
                created_at TEXT DEFAULT CURRENT_TIMESTAMP)''',
            '''CREATE TABLE IF NOT EXISTS signal_log (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                date TEXT NOT NULL,
                signal_type TEXT,
                code TEXT,
                name TEXT,
                score REAL,
                details TEXT,
                created_at TEXT DEFAULT CURRENT_TIMESTAMP)''',
            '''CREATE TABLE IF NOT EXISTS system_log (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                date TEXT NOT NULL,
                level TEXT,
                module TEXT,
                message TEXT,
                created_at TEXT DEFAULT CURRENT_TIMESTAMP)''',
        )
        db = sqlite3.connect(self._db_file())
        try:
            for statement in ddl:
                db.execute(statement)
            db.commit()
        finally:
            db.close()

    def _log_trade(self, trade):
        """Append a trade row to SQLite and bump state['trade_count']."""
        self._db_execute(
            '''INSERT INTO trades
            (date,code,name,direction,qty,price,amount,commission,stamp_tax,net_amount,pnl,pnl_pct,reason,factor_score)
            VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?)''',
            (trade['date'], trade['code'], trade.get('name', ''),
             trade['direction'], trade.get('qty', 0), trade.get('price', 0),
             trade.get('amount', 0), trade.get('commission', 0),
             trade.get('stamp_tax', 0), trade.get('net_amount', 0),
             trade.get('pnl'), trade.get('pnl_pct'), trade.get('reason', ''),
             trade.get('factor_score')))
        self.state['trade_count'] = self.state.get('trade_count', 0) + 1

    def _log_daily(self, date, nav_details):
        """Insert or replace the NAV row for ``date``."""
        self._db_execute(
            '''INSERT OR REPLACE INTO daily_log
            (date,cash,market_value,total_nav,holdings_count,return_pct,details)
            VALUES (?,?,?,?,?,?,?)''',
            (date, nav_details['cash'], nav_details['market_value'],
             nav_details['total_nav'], nav_details['holdings_count'],
             nav_details['return_pct'], json.dumps(nav_details, ensure_ascii=False)))

    def _log_signal(self, date, signal_type, code, name, score, details=''):
        """Append a buy/sell signal row."""
        self._db_execute(
            '''INSERT INTO signal_log
            (date,signal_type,code,name,score,details)
            VALUES (?,?,?,?,?,?)''',
            (date, signal_type, code, name, score, details))

    def _log_system(self, date, level, module, message):
        """Append a system event row."""
        self._db_execute(
            '''INSERT INTO system_log
            (date,level,module,message)
            VALUES (?,?,?,?)''', (date, level, module, message))

    # ─── Core trading logic ───

    def buy(self, code, name, price, qty, date, reason='', factor_score=0):
        """Buy ``qty`` shares of ``code`` at ``price`` and record the trade.

        Commission is ``max(amount * commission_rate, min_commission)``; no
        stamp tax is charged on buys. Buying a code that is already held adds
        to that position (quantity-weighted average cost) instead of
        overwriting it.

        Returns:
            bool: True on success; False when qty/price are not positive or
            cash does not cover amount + commission.
        """
        if qty <= 0 or price <= 0:
            log.warning(f"无效买入参数: {code} qty={qty} price={price}")
            return False

        amount = qty * price
        commission = max(amount * self.commission_rate, self.min_commission)
        total_cost = amount + commission

        if self.state['cash'] < total_cost:
            log.warning(f"资金不足: 需{total_cost:.0f} 有{self.state['cash']:.0f}")
            return False

        self.state['cash'] -= total_cost
        positions = self.state['positions']
        held = positions.get(code)
        if held:
            new_qty = held['qty'] + qty
            prev_cost = held.get('total_cost', held['qty'] * held['avg_cost'])
            held.update(
                qty=new_qty,
                avg_cost=(held['qty'] * held['avg_cost'] + amount) / new_qty,
                current_price=price,
                last_buy_date=date,
                total_cost=prev_cost + total_cost,
            )
        else:
            positions[code] = {
                'name': name, 'qty': qty,
                'avg_cost': price, 'current_price': price,
                'buy_date': date, 'factor_score': factor_score,
                # Everything paid to open the position, including commission.
                'total_cost': total_cost,
            }

        trade = {
            'date': date, 'code': code, 'name': name,
            'direction': 'buy', 'qty': qty, 'price': price,
            'amount': round(amount, 2), 'commission': round(commission, 2),
            'net_amount': round(total_cost, 2),
            'reason': reason, 'factor_score': factor_score,
        }
        self._log_trade(trade)
        self._log_signal(date, 'buy', code, name, factor_score, reason)
        log.info(f"🟢 买入 {name}({code}) {qty}股@¥{price:.2f} = ¥{amount:,.0f} [{reason}]")
        return True

    def sell(self, code, price, date, reason=''):
        """Sell the whole position in ``code`` at ``price``.

        Charges commission and stamp tax. Realised PnL is measured against the
        total cash paid for the position (including buy commission); positions
        saved before ``total_cost`` was tracked fall back to ``qty * avg_cost``.

        Returns:
            bool: False if ``code`` is not held, else True.
        """
        if code not in self.state['positions']:
            return False

        pos = self.state['positions'][code]
        qty = pos['qty']
        amount = qty * price
        commission = max(amount * self.commission_rate, self.min_commission)
        stamp_tax = amount * self.stamp_tax_rate
        net_revenue = amount - commission - stamp_tax

        buy_cost = pos.get('total_cost', qty * pos['avg_cost'])
        pnl = net_revenue - buy_cost
        pnl_pct = (price / pos['avg_cost'] - 1) * 100

        self.state['cash'] += net_revenue
        del self.state['positions'][code]

        trade = {
            'date': date, 'code': code, 'name': pos['name'],
            'direction': 'sell', 'qty': qty, 'price': price,
            'amount': round(amount, 2), 'commission': round(commission, 2),
            'stamp_tax': round(stamp_tax, 2), 'net_amount': round(net_revenue, 2),
            'pnl': round(pnl, 2), 'pnl_pct': round(pnl_pct, 2),
            'reason': reason, 'factor_score': pos.get('factor_score', 0),
        }
        self._log_trade(trade)
        self._log_signal(date, 'sell', code, pos['name'], 0, f'{reason} PnL:{pnl:+.0f}({pnl_pct:+.1f}%)')

        emoji = '✅' if pnl > 0 else '❌'
        log.info(f"{emoji} 卖出 {pos['name']}({code}) {qty}股@¥{price:.2f} "
                 f"PnL ¥{pnl:+,.0f}({pnl_pct:+.1f}%) [{reason}]")
        return True

    def get_nav(self):
        """Return cash, market value, total NAV, holding count and return % (vs initial_cash)."""
        mv = sum(
            pos.get('current_price', pos['avg_cost']) * pos['qty']
            for pos in self.state['positions'].values()
        )
        total = self.state['cash'] + mv
        return {
            'cash': round(self.state['cash'], 2),
            'market_value': round(mv, 2),
            'total_nav': round(total, 2),
            'holdings_count': len(self.state['positions']),
            'return_pct': round((total / self.initial_cash - 1) * 100, 4),
        }

    def get_today_price(self, code):
        """Best-effort latest close for ``code``.

        Tries today's bar, then the most recent bar, then the price stored on
        the position; returns 0 if none is available.
        """
        today = datetime.now().strftime('%Y%m%d')
        try:
            df = self.dm.pro.daily(ts_code=code, start_date=today, end_date=today)
            if df is not None and not df.empty:
                return float(df.iloc[0]['close'])
        except Exception as e:
            log.debug(f"获取{code}当日价格失败: {e}")

        # Most recent trading day
        try:
            df = self.dm.pro.daily(ts_code=code, limit=1)
            if df is not None and not df.empty:
                return float(df.iloc[0]['close'])
        except Exception as e:
            log.debug(f"获取{code}最新价格失败: {e}")

        # Last known price stored on the position
        if code in self.state['positions']:
            return self.state['positions'][code].get('current_price', 0)
        return 0

    def check_stop_loss_take_profit(self, date, prices=None):
        """Return ``[(kind, code, price, pnl_fraction)]`` for positions hitting an exit.

        ``kind`` is 'stop_loss' or 'take_profit'; ``pnl_fraction`` is
        ``price / avg_cost - 1`` (e.g. -0.08). Updates ``current_price`` on
        every position with a usable price. Positions bought on ``date`` are
        not eligible (A-share T+1: shares cannot be sold on the purchase day).
        """
        prices = prices or {}
        triggered = []
        for code, pos in list(self.state['positions'].items()):
            p = prices.get(code, 0)
            if p <= 0:
                p = self.get_today_price(code)
            if p <= 0:
                continue

            pos['current_price'] = p
            if pos.get('last_buy_date', pos.get('buy_date')) == date:
                continue

            pnl_pct = p / pos['avg_cost'] - 1
            if pnl_pct <= self.stop_loss_pct:
                triggered.append(('stop_loss', code, p, pnl_pct))
            elif pnl_pct >= self.take_profit_pct:
                triggered.append(('take_profit', code, p, pnl_pct))
        return triggered

    def scan_buy(self, date, stock_data=None, date_to_idx=None, prices=None):
        """Return buy candidates ``[(code, name, price, qty, score)]`` from the factor ranking.

        At most ``min(top_n_buy, free slots)`` candidates. Each is sized at
        ``min(90% of the cash still available, total_nav * max_position_pct)``,
        rounded down to whole 100-share lots. Cash committed to earlier
        candidates (amount + commission) is deducted before sizing the next,
        so the whole list is affordable. Candidates priced under 3, already
        held, or below ``min_buy_amount`` are skipped.
        """
        if not (stock_data and date_to_idx):
            # Live mode without preloaded data: nothing to score.
            return []

        slots = min(self.top_n_buy, self.max_holdings - len(self.state['positions']))
        if slots <= 0:
            return []

        scored = self.factor.score_all(stock_data, date, date_to_idx, self.pool_size)
        prices = prices or {}
        per_position_cap = self.get_nav()['total_nav'] * self.max_position_pct
        cash_left = self.state['cash']

        buy_list = []
        for code, score in scored:
            if len(buy_list) >= slots:
                break
            if code in self.state['positions']:
                continue
            p = prices.get(code, 0)
            if p < 3:  # also rejects missing / non-positive prices
                continue

            budget = min(cash_left * 0.9, per_position_cap)
            qty = int(budget / p / 100) * 100
            if qty < 100:
                continue
            amount = qty * p
            if amount < self.min_buy_amount:
                continue
            cost = amount + max(amount * self.commission_rate, self.min_commission)
            if cost > cash_left:
                continue

            name = stock_data[code].get('name', code)
            buy_list.append((code, name, p, qty, score))
            cash_left -= cost
        return buy_list

    # ─── Full daily run ───

    def _trade_days_since_scan(self, trade_dates, day_idx):
        """Trading days between the last scan and ``trade_dates[day_idx]``.

        Uses the stored ``last_scan_date`` rather than ``last_scan_idx``: the
        index refers to the trade-date window loaded in that run, and the
        window is re-anchored to "now" on every run, so indices from different
        runs are not comparable. Dates are YYYYMMDD strings, which sort
        chronologically. If the last scan predates the window, the whole
        window length is returned.
        """
        last_date = self.state.get('last_scan_date')
        if not last_date:
            return day_idx - self.state.get('last_scan_idx', -self.scan_interval)
        return sum(1 for d in trade_dates[:day_idx + 1] if d > last_date)

    def _record_nav(self, date, nav):
        """Append (or replace, on a re-run of the same date) today's NAV; keep 365 entries."""
        history = self.state.setdefault('nav_history', [])
        entry = {'date': date, 'nav': nav['total_nav'], 'ret': nav['return_pct']}
        if history and history[-1].get('date') == date:
            history[-1] = entry
        else:
            history.append(entry)
        if len(history) > 365:
            self.state['nav_history'] = history[-365:]

    def run_daily(self, date=None):
        """Run one trading day.

        1. Load history  2. Update position prices  3. Stop-loss / take-profit
        4. Refill scan (when needed)  5. Log NAV and persist state
        """
        if date is None:
            date = datetime.now().strftime('%Y%m%d')

        log.info(f"═══════════════════════════════════════")
        log.info(f"📋 开始处理 {date}")

        # 1. Load data
        log.info("📥 加载历史数据...")
        stock_data, trade_dates = self.dm.load_all_history(days=300)

        if date not in trade_dates:
            log.info(f"{date} 非交易日,跳过")
            return

        day_idx = trade_dates.index(date)
        if day_idx < 60:
            log.warning("数据不足60天,无法运行")
            return

        # Per-stock {date: bar index}
        code_date_idx = {
            code: {dt: i for i, dt in enumerate(sd['dates'])}
            for code, sd in stock_data.items()
        }

        # 2. Today's closes (positive only)
        today_prices = {}
        for code, sd in stock_data.items():
            eidx = code_date_idx[code].get(date)
            if eidx is not None and eidx < len(sd['closes']):
                p = sd['closes'][eidx]
                if p > 0:
                    today_prices[code] = p

        for code, pos in self.state['positions'].items():
            if code in today_prices:
                pos['current_price'] = today_prices[code]

        # 3. Stop-loss / take-profit
        triggered = self.check_stop_loss_take_profit(date, today_prices)
        for reason, code, price, pnl_pct in triggered:
            r = '止损' if reason == 'stop_loss' else '止盈'
            # pnl_pct is a fraction (e.g. -0.085); display it as a percentage.
            self.sell(code, price, date, f'{r}{pnl_pct * 100:+.1f}%')

        # 4. Refill scan
        elapsed = self._trade_days_since_scan(trade_dates, day_idx)
        need_scan = bool(triggered) or elapsed >= self.scan_interval

        if need_scan and len(self.state['positions']) < self.max_holdings:
            log.info(f"🔍 扫描补仓 (triggered={len(triggered)>0}, interval={elapsed}天)")
            buy_list = self.scan_buy(date, stock_data, code_date_idx, today_prices)
            for code, name, price, qty, score in buy_list:
                self.buy(code, name, price, qty, date,
                         f'因子{score:.3f}', score)
            self.state['last_scan_date'] = date
            self.state['last_scan_idx'] = day_idx

        # 5. Logs
        nav = self.get_nav()
        self._log_daily(date, nav)
        self._record_nav(date, nav)
        self.save_state()

        # End-of-day report
        log.info(f"📊 日终净值: ¥{nav['total_nav']:,.2f} "
                 f"({nav['return_pct']:+.2f}%) | "
                 f"现金 ¥{nav['cash']:,.0f} | "
                 f"持仓 {nav['holdings_count']}只")

        self._log_system(date, 'INFO', 'engine',
                         f"日终净值{nav['total_nav']:.2f} 收益{nav['return_pct']:+.2f}% 持仓{nav['holdings_count']}只")

    # ─── Report ───

    def get_report(self):
        """Return a multi-line text report: NAV, parameters, holdings, last 10 trades."""
        nav = self.get_nav()
        lines = [
            "=" * 50,
            "📊 纯信号交易系统状态",
            "=" * 50,
            f"日期: {datetime.now().strftime('%Y-%m-%d %H:%M')}",
            f"本金: ¥{self.initial_cash:,.0f}",
            f"净值: ¥{nav['total_nav']:,.2f} ({nav['return_pct']:+.2f}%)",
            f"现金: ¥{nav['cash']:,.0f}",
            f"持仓: {nav['holdings_count']}只",
            f"交易次数: {self.state.get('trade_count', 0)}",
            "",
            "策略参数:",
            f" 止损: {self.stop_loss_pct*100:.0f}% | 止盈: +{self.take_profit_pct*100:.0f}%",
            f" 单只上限: {self.max_position_pct*100:.0f}% | 最多: {self.max_holdings}只",
            f" 扫描间隔: {self.scan_interval}天",
            "",
        ]

        if self.state['positions']:
            lines.append("持仓明细:")
            for code, pos in self.state['positions'].items():
                p = pos.get('current_price', pos['avg_cost'])
                pnl = (p / pos['avg_cost'] - 1) * 100
                lines.append(
                    f" {pos['name']}({code}) {pos['qty']}股 "
                    f"@¥{pos['avg_cost']:.2f} 现¥{p:.2f} {pnl:+.1f}%"
                )

        # Recent trades
        db = sqlite3.connect(self._db_file())
        try:
            rows = db.execute(
                'SELECT date,code,name,direction,qty,price,pnl,pnl_pct,reason '
                'FROM trades ORDER BY id DESC LIMIT 10'
            ).fetchall()
        finally:
            db.close()

        if rows:
            lines.append("\n最近交易:")
            for d, code, name, direction, qty, price, pnl, pnl_pct, reason in rows:
                if direction == 'buy':
                    lines.append(f" 🟢 {d} {name}({code}) {qty}股@¥{price:.2f}")
                else:
                    # NULL pnl columns would otherwise break the format spec.
                    pnl = pnl or 0
                    pnl_pct = pnl_pct or 0
                    lines.append(
                        f" {'✅' if pnl > 0 else '❌'} {d} {name}({code}) "
                        f"{qty}股@¥{price:.2f} {pnl:+.0f}({pnl_pct:+.1f}%) [{reason}]"
                    )

        return '\n'.join(lines)
# =====================================================================
# CLI
# =====================================================================
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='纯信号交易引擎')
    parser.add_argument('--action', default='daily',
                        choices=['daily', 'status', 'report', 'init', 'backfill'])
    parser.add_argument('--date', default=None, help='指定日期 YYYYMMDD')
    parser.add_argument('--cash', type=float, default=None, help='初始资金')
    args = parser.parse_args()

    cfg = load_config()
    # Explicit None check so a passed value is never silently ignored; a
    # non-positive amount is rejected because get_nav divides by initial_cash.
    if args.cash is not None:
        if args.cash <= 0:
            parser.error('--cash must be positive')
        cfg['initial_cash'] = args.cash

    engine = TradingEngine(cfg)

    if args.action == 'init':
        # NOTE: an existing state.json is loaded and re-saved unchanged; in
        # that case --cash only changes the baseline used for return_pct.
        log.info("初始化交易系统...")
        engine.save_state()
        log.info(f"初始资金: ¥{engine.initial_cash:,.0f}")
        print(engine.get_report())

    elif args.action == 'daily':
        engine.run_daily(args.date)

    elif args.action in ('status', 'report'):
        print(engine.get_report())

    elif args.action == 'backfill':
        # Backfill mode: run the daily logic for the given date to build the
        # initial positions.
        log.info("回填模式:加载近期数据并建仓...")
        engine.run_daily(args.date)