Skip to content
Closed
195 changes: 195 additions & 0 deletions build/generate_private_tests.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
const { match } = require('assert');
const fs = require('fs');

const SOURSE_FILE_NAME = './kucoin/client.py'
const TARGET_FILE_NAME = './tests/test_private_requests_generated.py'
const METHOD_DEFINITION_MATCH = /def\s(\w+)/
const METHOD_CALL_MATCH = /self\.(\w+)\(/
const REQUEST_METHODS = ['_get', '_post', '_put', '_delete']
const METHOD_NAME_MATCH = /(\w+)\(/
const ARGUMENTS_MATCH = /\((.*)\)/
const REST_API_URL = 'https://api.kucoin.com'
const REST_FUTURES_API_URL = 'https://api-futures.kucoin.com'
const API_VERSIONS = {
'API_VERSION': 'v1',
'API_VERSION2': 'v2',
'API_VERSION3': 'v3'
}

const MANDATORY_ARGS = [
'sub_user_id',
'include_base_ammount',
'sub_name',
'passphrase',
'remark',
'api_key',
'account_id',
'account_type',
'currency',
'type',
'client_oid',
'amount',
'from_account_type',
'to_account_type',
'pay_account_type',
'withdrawal_id',
'symbols',
'symbol',
'side',
'order_list',
'order_id',
'cancel_size',
'timeout',
'stop_price',
'size',
'price',
'limit_price',
'orders_data',
'trade_type',
'interest_rate',
'purchase_order_no'
]

function getPrivateMethodsArgumentsAndRequests (data) {
const lines = data.split ('\n')
lines.push('\n')
const methods = {}
let metodName = ''
let methodArgs = []
let lastMethodDefinition = ''
for (let i = 0; i < lines.length; i++) {
let line = lines[i]
const methodDefinition = line.match (METHOD_DEFINITION_MATCH)
const methodCall = line.match (METHOD_CALL_MATCH)
if (methodDefinition) { // if line is method definition
while (!line.includes (':')) {
i++
line += lines[i].trim()
}
lastMethodDefinition = line.replace ('def ', '')
.replaceAll (' ', '')
.replaceAll (',)', ')')
.replace (':', '')
continue // we need to check if this is private method
} else if (methodCall) {
let name = methodCall[1]
if (!REQUEST_METHODS.includes (name)) { // if this is not request method just skip
continue
}
if (line.endsWith ('(')) { // if request method is called not in one line
i++
let nextLine = lines[i].trim()
while (nextLine !== ')') {
line += nextLine
i++
nextLine = lines[i].trim()
}
line += nextLine
}
if (line.match (/[^,]+, True/)) { // if request method is called with second argument True
if (line.indexOf ('path') !== -1) {
continue // skip methods with path argument
}
[ metodName, methodArgs ] = getNameAndArgs (lastMethodDefinition)
methods[metodName] = {
args: methodArgs,
request: getParamsFromRequestCall (line.trim ().replace ('return self._', '').replaceAll ("'", '"'))
}
}
}
}
return methods
}

function getNameAndArgs (line) {
let name = line.match (METHOD_NAME_MATCH)[1]
let args = line.match (ARGUMENTS_MATCH)[1].trim().split (',').filter (arg => (arg !== 'self' && arg !== '**params'))
return [ name, args ]
}

function getParamsFromRequestCall (line) {
const matchMethodAndEndpointPattern = /(\w+)\("([^"]*)"/
const matchFuturePattern = /is_futures=(\w+)/
const matchVersionPattern = /version=self.(\w+)/
const matchFormatPattern = /.format\((\w+)\)/
const matchMethodAndEndpoint = line.match (matchMethodAndEndpointPattern)
const matchFuture = line.match (matchFuturePattern)
const matchVersion = line.match (matchVersionPattern)
const isFutures = (matchFuture && matchFuture[1] === 'True') ? true : false
const baseUrl = isFutures ? REST_FUTURES_API_URL : REST_API_URL
const version = matchVersion ? API_VERSIONS[matchVersion[1]] : API_VERSIONS.API_VERSION
let endpoint = matchMethodAndEndpoint[2]
const matchFormat = line.match (matchFormatPattern)
if (matchFormat) {
endpoint = endpoint.replace (matchFormat[0], '').replace ('{}', '{' + matchFormat[1] + '}')
}
const url = baseUrl + '/api/' + version + '/' + endpoint
return {
full: line,
url: url,
method: matchMethodAndEndpoint[1],
endpoint: endpoint,
isFutures: isFutures,
}
}

function generateTests (methods) {
const tests = [
'import requests_mock',
'import pytest',
'from aioresponses import aioresponses'
]
const methodNames = Object.keys (methods)
for (let methodName of methodNames) {
const method = methods[methodName]
const mandatoryArgs = generateMandatoryArgs (method)
let functionArgs = ''
for (let arg of mandatoryArgs) {
functionArgs += '"' + arg + '", '
}
functionArgs = functionArgs.slice (0, -2)
const request = method.request
let url = request.url
const paramInParth = url.match (/{(\w+)}/)
if (paramInParth) {
url = url.replace (paramInParth[0], paramInParth[1])
}

const test = [
'\n',
'def test_' + methodName + '(client):',
' with requests_mock.mock() as m:',
' m.' + request.method + '("' + url + '")',
' client.' + methodName + '('+ functionArgs + ')',
' assert m.last_request.url == "' + url + '"'

]
tests.push (...test)
}
return tests.join ('\n')
}

function generateMandatoryArgs (method) {
const args = method.args
return args.filter (arg => arg.indexOf ('=') === -1)
}

function main () {
fs.readFile (SOURSE_FILE_NAME, 'utf8', (err, data) => {
if (err) {
console.error (err)
return
}
const privateMethodsArgumentsAndRequests = getPrivateMethodsArgumentsAndRequests (data)
const tests = generateTests (privateMethodsArgumentsAndRequests)
fs.writeFile (TARGET_FILE_NAME, tests, (err) => {
if (err) {
console.error (err)
return
}
console.log (TARGET_FILE_NAME + ' file is generated')
})
})
}

main ()
35 changes: 15 additions & 20 deletions kucoin/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,12 +469,8 @@ async def get_symbol(self, symbol=None, **params):

"""

data = {}
if symbol:
data["symbol"] = symbol

return await self._get(
"symbol", False, api_version=self.API_VERSION2, data=dict(data, **params)
"symbols/{}".format(symbol), False, api_version=self.API_VERSION2, **params
)

async def get_ticker(self, symbol, **params):
Expand Down Expand Up @@ -763,7 +759,7 @@ async def get_trade_histories(self, symbol, **params):

return await self._get("market/histories", False, data=dict(data, **params))

async def get_kline_data(self, symbol, kline_type="5min", start=None, end=None, **params):
async def get_klines(self, symbol, kline_type="5min", start=None, end=None, **params):
"""Get kline data

https://www.kucoin.com/docs/rest/spot-trading/market-data/get-klines
Expand All @@ -773,7 +769,7 @@ async def get_kline_data(self, symbol, kline_type="5min", start=None, end=None,

:param symbol: Name of symbol e.g. KCS-BTC
:type symbol: string
:param kline_type: type of symbol, type of candlestick patterns: 1min, 3min, 5min, 15min, 30min, 1hour, 2hour,
:param kline_type: type of candlestick patterns: 1min, 3min, 5min, 15min, 30min, 1hour, 2hour,
4hour, 6hour, 8hour, 12hour, 1day, 1week
:type kline_type: string
:param start: Start time as unix timestamp (optional) default start of day in UTC
Expand All @@ -783,7 +779,7 @@ async def get_kline_data(self, symbol, kline_type="5min", start=None, end=None,

.. code:: python

klines = client.get_kline_data('KCS-BTC', '5min', 1507479171, 1510278278)
klines = client.get_klines('KCS-BTC', '5min', 1507479171, 1510278278)

:returns: ApiResponse

Expand Down Expand Up @@ -1277,7 +1273,7 @@ async def futures_get_trade_histories(self, symbol, **params):
)

async def futures_get_klines(
self, symbol, kline_type="5min", start=None, end=None, **params
self, symbol, kline_type=5, start=None, end=None, **params
):
"""Get kline data

Expand All @@ -1287,17 +1283,16 @@ async def futures_get_klines(

:param symbol: Name of symbol e.g. XBTUSDTM
:type symbol: string
:param kline_type: type of symbol, type of candlestick patterns: 1min, 3min, 5min, 15min, 30min, 1hour, 2hour,
4hour, 6hour, 8hour, 12hour, 1day, 1week
:type kline_type: string
:param kline_type: type of candlestick in minutes: 1, 5, 50 etc.
:type kline_type: int
:param start: Start time as unix timestamp (optional) default start of day in UTC
:type start: int
:param end: End time as unix timestamp (optional) default now in UTC
:type end: int

.. code:: python

klines = client.futures_get_klines('XBTUSDTM', '5min', 1507479171, 1510278278)
klines = client.futures_get_klines('XBTUSDTM', 5, 1507479171, 1510278278)

:returns: ApiResponse

Expand Down Expand Up @@ -8919,7 +8914,7 @@ async def get_fills(
if limit:
data["pageSize"] = limit

return await self._get("fills", False, data=dict(data, **params))
return await self._get("fills", True, data=dict(data, **params))

async def get_recent_fills(self, **params):
"""Get a list of recent fills.
Expand Down Expand Up @@ -9232,7 +9227,7 @@ async def futures_get_fills(
if limit:
data["pageSize"] = limit

return await self._get("fills", False, is_futures=True, data=dict(data, **params))
return await self._get("fills", True, is_futures=True, data=dict(data, **params))

async def futures_get_recent_fills(self, symbol=None, **params):
"""Get a list of recent futures fills.
Expand Down Expand Up @@ -9287,7 +9282,7 @@ async def futures_get_recent_fills(self, symbol=None, **params):
data["symbol"] = symbol

return await self._get(
"recentFills", False, is_futures=True, data=dict(data, **params)
"recentFills", True, is_futures=True, data=dict(data, **params)
)

async def futures_get_active_order_value(self, symbol, **params):
Expand Down Expand Up @@ -9324,7 +9319,7 @@ async def futures_get_active_order_value(self, symbol, **params):
data = {"symbol": symbol}

return await self._get(
"openOrderStatistics", False, is_futures=True, data=dict(data, **params)
"openOrderStatistics", True, is_futures=True, data=dict(data, **params)
)

# Margin Info Endpoints
Expand Down Expand Up @@ -9378,7 +9373,7 @@ async def margin_get_leverage_token_info(self, currency=None, **params):
async def margin_get_all_trading_pairs_mark_prices(self, **params):
"""Get a list of trading pairs and their mark prices

https://www.kucoin.com/docs/rest/margin-trading/margin-info/get-all-trading-pairs-mark-price
https://www.kucoin.com/docs/rest/margin-trading/margin-info/get-all-margin-trading-pairs-mark-prices

.. code:: python

Expand Down Expand Up @@ -10295,7 +10290,7 @@ async def margin_lending_get_currency_info(self, currency=None, **params):

return await self._get(
"project/list",
False,
True,
api_version=self.API_VERSION3,
data=dict(data, **params),
)
Expand Down Expand Up @@ -10341,7 +10336,7 @@ async def margin_lending_get_interest_rate(self, currency, **params):

return await self._get(
"project/marketInterestRatet",
False,
True,
api_version=self.API_VERSION3,
data=dict(data, **params),
)
Expand Down
Loading
Loading