Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions scripts/data_collector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import re
import copy
import importlib
import json
import time
import bisect
import pickle
Expand Down Expand Up @@ -200,7 +201,7 @@ def _get_symbol():
"""
# url = "http://99.push2.eastmoney.com/api/qt/clist/get?pn=1&pz=10000&po=1&np=1&fs=m:0+t:6,m:0+t:80,m:1+t:2,m:1+t:23,m:0+t:81+s:2048&fields=f12"

base_url = "http://99.push2.eastmoney.com/api/qt/clist/get"
base_url = "https://99.push2.eastmoney.com/api/qt/clist/get"
params = {
"pn": 1, # page number
"pz": 100, # page size, default to 100
Expand All @@ -216,7 +217,7 @@ def _get_symbol():
while True:
params["pn"] = page
try:
resp = requests.get(base_url, params=params, timeout=None)
resp = requests.get(base_url, params=params, timeout=60)
resp.raise_for_status()
data = resp.json()

Expand Down Expand Up @@ -276,10 +277,10 @@ def _get_symbol():
symbol_cache_path.parent.mkdir(parents=True, exist_ok=True)
if symbol_cache_path.exists():
with symbol_cache_path.open("rb") as fp:
cache_symbols = restricted_pickle_load(fp)
cache_symbols = json.load(fp)
symbols |= cache_symbols
with symbol_cache_path.open("wb") as fp:
pickle.dump(symbols, fp)
json.dump(list(symbols), fp)

_HS_SYMBOLS = sorted(list(symbols))

Expand Down Expand Up @@ -334,7 +335,7 @@ def _get_nyse():
"maxResultsPerPage": 10000,
"filterToken": "",
}
resp = requests.post(url, json=_parms, timeout=None)
resp = requests.post(url, json=_parms, timeout=60)
if resp.status_code != 200:
raise ValueError("request error")

Expand Down Expand Up @@ -424,7 +425,7 @@ def _get_ibovespa():

# Request
agent = {"User-Agent": "Mozilla/5.0"}
page = requests.get(url, headers=agent, timeout=None)
page = requests.get(url, headers=agent, timeout=60)

# BeautifulSoup
soup = BeautifulSoup(page.content, "html.parser")
Expand Down Expand Up @@ -470,8 +471,8 @@ def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list:

@deco_retry
def _get_eastmoney():
url = "http://fund.eastmoney.com/js/fundcode_search.js"
resp = requests.get(url, timeout=None)
url = "https://fund.eastmoney.com/js/fundcode_search.js"
resp = requests.get(url, timeout=60)
if resp.status_code != 200:
raise ValueError("request error")
try:
Expand Down Expand Up @@ -655,7 +656,11 @@ def get_instruments(
$ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies

"""
_cur_module = importlib.import_module("data_collector.{}.collector".format(market_index))
_VALID_MARKET_INDICES = {"baostock_5min", "br_index", "cn_index", "us_index", "crypto", "fund", "pit", "yahoo"}
if market_index not in _VALID_MARKET_INDICES:
raise ValueError(f"Invalid market_index value: {market_index!r}")
_module_name = "data_collector.{}.collector".format(market_index) # nosemgrep: python.lang.security.audit.non-literal-import.non-literal-import
_cur_module = importlib.import_module(_module_name) # nosemgrep: python.lang.security.audit.non-literal-import.non-literal-import
obj = getattr(_cur_module, f"{index_name.upper()}Index")(
qlib_dir=qlib_dir, index_name=index_name, freq=freq, request_retry=request_retry, retry_sleep=retry_sleep
)
Expand Down Expand Up @@ -835,4 +840,5 @@ def calc_paused_num(df: pd.DataFrame, _date_field_name, _symbol_field_name):


if __name__ == "__main__":
assert len(get_hs_stock_symbols()) >= MINIMUM_SYMBOLS_NUM
if len(get_hs_stock_symbols()) < MINIMUM_SYMBOLS_NUM:
raise AssertionError(f"Expected at least {MINIMUM_SYMBOLS_NUM} symbols, got fewer.")