Skip to content

Commit

Permalink
Merge pull request #52 from cowprotocol/add_partner_fee
Browse files Browse the repository at this point in the history
Compute partner fees
  • Loading branch information
harisang authored Sep 6, 2024
2 parents c2b751e + f5dcae2 commit a62813b
Show file tree
Hide file tree
Showing 10 changed files with 292 additions and 80 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ pandas-stubs
types-psycopg2
types-requests
moralis
dune-client
dune-client
pytest
23 changes: 23 additions & 0 deletions src/compute_fees_single_hash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from hexbytes import HexBytes
from src.fees.compute_fees import compute_all_fees_of_batch
from src.helpers.config import logger


def log_token_data(title: str, data: dict, name: str):
logger.info(title)
for token, value in data.items():
logger.info(f"Token Address: {token}, {name}: {value}")


def main():
protocol_fees, partner_fees, network_fees = compute_all_fees_of_batch(
HexBytes(input("tx hash: "))
)
log_token_data("Protocol Fees:", protocol_fees, "Protocol Fee")
log_token_data("Partner Fees:", partner_fees, "Partner Fee")
log_token_data("Network Fees:", network_fees, "Network Fee")
# e.g. input: 0x980fa3f8ff95c504ba61e054e5c3e50ea36b892f865703b8a665564ac0beb1f4


if __name__ == "__main__":
main()
2 changes: 2 additions & 0 deletions src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
"0x875b6cb035bbd4ac6500fabc6d1e4ca5bdc58a3e2b424ccb5c24cdbebeb009a9"
)

NULL_ADDRESS = Web3.to_checksum_address("0x0000000000000000000000000000000000000000")

REQUEST_TIMEOUT = 5

# Time limit, currently set to 1 full day, after which Coingecko Token List is re-fetched (in seconds)
Expand Down
152 changes: 91 additions & 61 deletions src/fees/compute_fees.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,59 @@
import math
import os
from typing import Any
import requests
import json
from dotenv import load_dotenv
from eth_typing import Address
from eth_typing import ChecksumAddress
from hexbytes import HexBytes
from web3 import Web3

from src.constants import (
REQUEST_TIMEOUT,
)
import requests
from src.constants import REQUEST_TIMEOUT, NULL_ADDRESS

# types for trades


@dataclass
class Trade:
"""Class for"""
"""Class for describing a trade, together with the fees associated with it.
We note that we use the NULL address to indicate that there are no partner fees.
Note that in case an order is placed with the partner fee recipient being the null address,
the partner fee will instead be accounted for as protocol fee and will be withheld by the DAO.
"""

order_uid: HexBytes
sell_amount: int
buy_amount: int
sell_token: HexBytes
buy_token: HexBytes
limit_sell_amount: int
limit_buy_amount: int
kind: str
sell_token_clearing_price: int
buy_token_clearing_price: int
fee_policies: list["FeePolicy"]
def __init__(
self,
order_uid: HexBytes,
sell_amount: int,
buy_amount: int,
sell_token: HexBytes,
buy_token: HexBytes,
limit_sell_amount: int,
limit_buy_amount: int,
kind: str,
sell_token_clearing_price: int,
buy_token_clearing_price: int,
fee_policies: list["FeePolicy"],
partner_fee_recipient: ChecksumAddress,
):
self.order_uid = order_uid
self.sell_amount = sell_amount
self.buy_amount = buy_amount
self.sell_token = sell_token
self.buy_token = buy_token
self.limit_sell_amount = limit_sell_amount
self.limit_buy_amount = limit_buy_amount
self.kind = kind
self.sell_token_clearing_price = sell_token_clearing_price
self.buy_token_clearing_price = buy_token_clearing_price
self.fee_policies = fee_policies
self.partner_fee_recipient = partner_fee_recipient # if there is no partner, then its value is set to the null address

total_protocol_fee, partner_fee, network_fee = self.compute_all_fees()
self.total_protocol_fee = total_protocol_fee
self.partner_fee = partner_fee
self.network_fee = network_fee
return

def volume(self) -> int:
"""Compute volume of a trade in the surplus token"""
Expand Down Expand Up @@ -62,20 +88,28 @@ def surplus(self) -> int:
return current_limit_sell_amount - self.sell_amount
raise ValueError(f"Order kind {self.kind} is invalid.")

def raw_surplus(self) -> int:
"""Compute raw surplus of a trade in the surplus token
First, the application of protocol fees is reversed. Then, surplus of the resulting trade
is computed."""
def compute_all_fees(self) -> tuple[int, int, int]:
raw_trade = deepcopy(self)
for fee_policy in reversed(self.fee_policies):
partner_fee = 0
for i, fee_policy in enumerate(reversed(self.fee_policies)):
raw_trade = fee_policy.reverse_protocol_fee(raw_trade)
return raw_trade.surplus()
## we assume that partner fee is the last to be applied
if i == 0 and self.partner_fee_recipient != NULL_ADDRESS:
partner_fee = raw_trade.surplus() - self.surplus()
total_protocol_fee = raw_trade.surplus() - self.surplus()

def protocol_fee(self):
"""Compute protocol fees of a trade in the surplus token
Protocol fees are computed as the difference of raw surplus and surplus."""

return self.raw_surplus() - self.surplus()
surplus_fee = self.compute_surplus_fee() # in the surplus token
network_fee_in_surplus_token = surplus_fee - total_protocol_fee
if self.kind == "sell":
network_fee = int(
network_fee_in_surplus_token
* Fraction(
self.buy_token_clearing_price, self.sell_token_clearing_price
)
)
else:
network_fee = network_fee_in_surplus_token
return total_protocol_fee, partner_fee, network_fee

def surplus_token(self) -> HexBytes:
"""Returns the surplus token"""
Expand Down Expand Up @@ -336,6 +370,14 @@ def get_all_data(self, tx_hash: HexBytes) -> SettlementData:
buy_token_clearing_price = clearing_prices[buy_token]
fee_policies = self.parse_fee_policies(trade_data["feePolicies"])

app_data = json.loads(order_data["fullAppData"])
if "partnerFee" in app_data["metadata"].keys():
partner_fee_recipient = Web3.to_checksum_address(
HexBytes(app_data["metadata"]["partnerFee"]["recipient"])
)
else:
partner_fee_recipient = NULL_ADDRESS

trade = Trade(
order_uid=uid,
sell_amount=executed_sell_amount,
Expand All @@ -348,6 +390,7 @@ def get_all_data(self, tx_hash: HexBytes) -> SettlementData:
sell_token_clearing_price=sell_token_clearing_price,
buy_token_clearing_price=buy_token_clearing_price,
fee_policies=fee_policies,
partner_fee_recipient=partner_fee_recipient,
)
trades.append(trade)

Expand Down Expand Up @@ -436,48 +479,35 @@ def parse_fee_policies(
return fee_policies


# computing fees
def compute_fee_imbalances(
settlement_data: SettlementData,
) -> tuple[dict[str, tuple[str, int]], dict[str, tuple[str, int]]]:
# function that computes all fees of all orders in a batch
# Note that currently it is NOT working for CoW AMMs as they are not indexed.
def compute_all_fees_of_batch(
tx_hash: HexBytes,
) -> tuple[
dict[str, tuple[str, int]],
dict[str, tuple[str, int, str]],
dict[str, tuple[str, int]],
]:
orderbook_api = OrderbookFetcher()
settlement_data = orderbook_api.get_all_data(tx_hash)
protocol_fees: dict[str, tuple[str, int]] = {}
network_fees: dict[str, tuple[str, int]] = {}
partner_fees: dict[str, tuple[str, int, str]] = {}
for trade in settlement_data.trades:
# protocol fees
protocol_fee_amount = trade.protocol_fee()
protocol_fee_amount = trade.total_protocol_fee - trade.partner_fee
protocol_fee_token = trade.surplus_token()
protocol_fees[trade.order_uid.to_0x_hex()] = (
protocol_fee_token.to_0x_hex(),
protocol_fee_amount,
)
# network fees
surplus_fee = trade.compute_surplus_fee() # in the surplus token
network_fee = surplus_fee - protocol_fee_amount
if trade.kind == "sell":
network_fee_sell = int(
network_fee
* Fraction(
trade.buy_token_clearing_price, trade.sell_token_clearing_price
)
)
else:
network_fee_sell = network_fee

partner_fees[trade.order_uid.to_0x_hex()] = (
protocol_fee_token.to_0x_hex(),
trade.partner_fee,
trade.partner_fee_recipient,
)
network_fees[trade.order_uid.to_0x_hex()] = (
trade.sell_token.to_0x_hex(),
network_fee_sell,
trade.network_fee,
)

return protocol_fees, network_fees


# combined function


def batch_fee_imbalances(
tx_hash: HexBytes,
) -> tuple[dict[str, tuple[str, int]], dict[str, tuple[str, int]]]:
orderbook_api = OrderbookFetcher()
settlement_data = orderbook_api.get_all_data(tx_hash)
protocol_fees, network_fees = compute_fee_imbalances(settlement_data)
return protocol_fees, network_fees
return protocol_fees, partner_fees, network_fees
8 changes: 7 additions & 1 deletion src/helpers/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,20 +89,25 @@ def write_prices(

def write_fees(
self,
chain_name: str,
auction_id: int,
block_number: int,
tx_hash: str,
order_uid: str,
token_address: str,
fee_amount: float,
fee_type: str,
recipient: str,
):
"""Function attempts to write price data to the table."""
tx_hash_bytes = bytes.fromhex(tx_hash[2:])
token_address_bytes = bytes.fromhex(token_address[2:])
order_uid_bytes = bytes.fromhex(order_uid[2:])

query = read_sql_file("src/sql/insert_fee.sql")
final_recipient = None
if recipient != "":
final_recipient = bytes.fromhex(recipient[2:])

self.execute_and_commit(
query,
{
Expand All @@ -114,5 +119,6 @@ def write_fees(
"token_address": token_address_bytes,
"fee_amount": fee_amount,
"fee_type": fee_type,
"recipient": final_recipient,
},
)
6 changes: 3 additions & 3 deletions src/sql/insert_fee.sql
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
INSERT INTO fees_new (
chain_name, auction_id, block_number, tx_hash, order_uid, token_address, fee_amount,fee_type
) VALUES ( :chain_name, :auction_id, :block_number, :tx_hash, :order_uid, :token_address, :fee_amount, :fee_type
INSERT INTO fees_per_trade (
chain_name, auction_id, block_number, tx_hash, order_uid, token_address, fee_amount, fee_type, fee_recipient
) VALUES ( :chain_name, :auction_id, :block_number, :tx_hash, :order_uid, :token_address, :fee_amount, :fee_type, :fee_recipient
);
6 changes: 4 additions & 2 deletions src/test_single_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from web3 import Web3
from src.imbalances_script import RawTokenImbalances
from src.price_providers.price_feed import PriceFeed
from src.fees.compute_fees import batch_fee_imbalances
from src.fees.compute_fees import compute_all_fees_of_batch
from src.transaction_processor import calculate_slippage
from src.helpers.config import get_web3_instance, logger
from contracts.erc20_abi import erc20_abi
Expand All @@ -26,7 +26,9 @@ def __init__(self):

def compute_data(self, tx_hash: str):
token_imbalances = self.imbalances.compute_imbalances(tx_hash)
protocol_fees, network_fees = batch_fee_imbalances(HexBytes(tx_hash))
protocol_fees, partner_fees, network_fees = compute_all_fees_of_batch(
HexBytes(tx_hash)
)
slippage = calculate_slippage(token_imbalances, protocol_fees, network_fees)
eth_slippage = self.calculate_slippage_in_eth(slippage, tx_hash)

Expand Down
Loading

0 comments on commit a62813b

Please sign in to comment.