Repository URL to install this package:
|
Version:
1.19.2+sf.0 ▾
|
#!/usr/bin/python3
import gc
import threading
import time
import weakref
from hashlib import sha1
from pathlib import Path
from sqlite3 import OperationalError
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
from hexbytes import HexBytes
from web3.types import BlockData
from brownie._config import CONFIG, _get_data_folder
from brownie._singleton import _Singleton
from brownie.convert import Wei
from brownie.exceptions import BrownieEnvironmentError, CompilerError
from brownie.network import rpc
from brownie.project.build import DEPLOYMENT_KEYS
from brownie.utils.sql import Cursor
from .transaction import TransactionReceipt
from .web3 import _resolve_address, web3
# Registry of known contract objects keyed by address. Written by
# `_add_contract` / `_remove_contract`, read by `_find_contract` and
# `_get_current_dependencies`.
_contract_map: Dict = {}
# Weak references to objects registered through `_revert_register`; walked by
# `_notify_registry` whenever the chain is reverted or reset.
_revert_refs: List = []
# Module-wide cursor over "deployments.db" inside the brownie data folder.
cur = Cursor(_get_data_folder().joinpath("deployments.db"))
# Make sure the table holding contract source text (keyed by hash) exists
# before any deployment is read or written.
cur.execute("CREATE TABLE IF NOT EXISTS sources (hash PRIMARY KEY, source)")
class TxHistory(metaclass=_Singleton):
    """
    Singleton, list-like record of TransactionReceipt objects.

    Receipts are added through `_add_tx`. Receipts whose `status` is `-2`
    (dropped) are purged from the container on every attribute access.
    """

    def __init__(self) -> None:
        self._list: List = []
        self.gas_profile: Dict = {}
        # subscribe to revert / reset notifications
        _revert_register(self)

    def __repr__(self) -> str:
        in_console = CONFIG.argv["cli"] == "console"
        return str(self._list) if in_console else super().__repr__()

    def __getattribute__(self, name: str) -> Any:
        # Drop receipts with status -2 before resolving *any* attribute, so
        # every other method only ever sees live transactions.
        lookup = super().__getattribute__
        self._list = [tx for tx in lookup("_list") if tx.status != -2]
        return lookup(name)

    def __bool__(self) -> bool:
        return bool(self._list)

    def __contains__(self, item: Any) -> bool:
        return item in self._list

    def __iter__(self) -> Iterator:
        return iter(self._list)

    def __getitem__(self, key: int) -> TransactionReceipt:
        return self._list[key]

    def __len__(self) -> int:
        return len(self._list)

    def _reset(self) -> None:
        """Empty the container (invoked by `_notify_registry` on a reset)."""
        self._list.clear()

    def _revert(self, height: int) -> None:
        """Keep only the receipts whose `block_number` is at most `height`."""
        # NOTE(review): a receipt without a numeric block_number would make this
        # comparison fail - confirm whether such receipts can be present here.
        self._list = [tx for tx in self._list if tx.block_number <= height]

    def _add_tx(self, tx: TransactionReceipt) -> None:
        """Append `tx` unless it is already in the container."""
        if tx in self._list:
            return
        self._list.append(tx)

    def clear(self, only_confirmed: bool = False) -> None:
        """
        Remove transactions from the container.

        Arguments
        ---------
        only_confirmed : bool, optional
            If True, only receipts with a status of -1 (still pending) are
            kept. If False, everything is removed.
        """
        if not only_confirmed:
            self._list.clear()
            return
        self._list = [tx for tx in self._list if tx.status == -1]

    def copy(self) -> List:
        """Return the stored receipts as a new, shallow-copied list."""
        return self._list.copy()

    def filter(
        self, key: Optional[Callable] = None, **kwargs: Optional[Any]
    ) -> List[TransactionReceipt]:
        """
        Select transactions by attribute values and/or a predicate.

        Arguments
        ---------
        key : Callable, optional
            Predicate applied to every receipt that matched the keyword
            filters. It receives one argument and returns True or False.

        Keyword Arguments
        -----------------
        **kwargs : Any
            Attribute names mapped to the value each receipt must have.

        Returns
        -------
        List
            The matching TransactionReceipt objects, in container order.
        """

        def attrs_match(tx: Any) -> bool:
            return all(getattr(tx, name) == expected for name, expected in kwargs.items())

        matches = [tx for tx in self._list if attrs_match(tx)]
        if key is not None:
            matches = [tx for tx in matches if key(tx)]
        return matches

    def wait(self, key: Optional[Callable] = None, **kwargs: Optional[Any]) -> None:
        """
        Block until no matching transaction is pending.

        The candidates are produced by `TxHistory.filter` with `status=-1`
        added to the given filters. With no arguments, every transaction in
        the container is considered.

        Arguments
        ---------
        key : Callable, optional
            Predicate forwarded to `TxHistory.filter`.

        Keyword Arguments
        -----------------
        **kwargs : Any
            Attribute filters forwarded to `TxHistory.filter`.
        """
        while True:
            # re-query every round: earlier waits may have confirmed others
            still_pending = self.filter(key, status=-1, **kwargs)
            if not still_pending:
                return
            still_pending[0]._confirmed.wait()

    def from_sender(self, account: str) -> List:
        """Return the transactions whose sender equals `account`."""
        return [tx for tx in self._list if tx.sender == account]

    def to_receiver(self, account: str) -> List:
        """Return the transactions whose receiver equals `account`."""
        return [tx for tx in self._list if tx.receiver == account]

    def of_address(self, account: str) -> List:
        """Return the transactions where `account` is receiver or sender."""
        return [tx for tx in self._list if tx.receiver == account or tx.sender == account]

    def _gas(self, fn_name: str, gas_used: int, is_success: bool) -> None:
        """
        Fold one gas measurement into `gas_profile[fn_name]`.

        Tracks the floor-divided running average, high, low and count over all
        samples, plus a separate count and average over successful samples.
        """
        stats = self.gas_profile.setdefault(fn_name, {})
        if not stats:
            # first sample for this function: seed every statistic
            stats.update(
                avg=gas_used,
                high=gas_used,
                low=gas_used,
                count=1,
                count_success=1 if is_success else 0,
                avg_success=gas_used if is_success else 0,
            )
            return

        prior_count = stats["count"]
        stats["avg"] = (stats["avg"] * prior_count + gas_used) // (prior_count + 1)
        stats["high"] = max(stats["high"], gas_used)
        stats["low"] = min(stats["low"], gas_used)
        stats["count"] = prior_count + 1

        if not is_success:
            return

        prior_success = stats["count_success"]
        stats["count_success"] = prior_success + 1
        if stats["avg_success"]:
            stats["avg_success"] = (stats["avg_success"] * prior_success + gas_used) // (
                prior_success + 1
            )
        else:
            stats["avg_success"] = gas_used
class Chain(metaclass=_Singleton):
    """
    Singleton, list-like view of the connected chain.

    Gives access to block data and wraps the RPC actions for snapshots,
    mining, time travel and undo / redo of transactions.
    """

    def __init__(self) -> None:
        self._time_offset: int = 0
        # snapshot bookkeeping
        self._snapshot_id: Optional[int] = None
        self._reset_id: Optional[int] = None
        self._current_id: Optional[int] = None
        # undo / redo state, guarded by the lock
        self._undo_lock = threading.Lock()
        self._undo_buffer: List = []
        self._redo_buffer: List = []
        # lazily filled caches
        self._chainid: Optional[int] = None
        self._block_gas_time: int = -1
        self._block_gas_limit: int = 0

    def __repr__(self) -> str:
        try:
            return f"<Chain object (chainid={self.id}, height={self.height})>"
        except Exception:
            # any failure while querying the node is reported as "disconnected"
            return "<Chain object (disconnected)>"

    def __len__(self) -> int:
        """
        Return the number of blocks (latest block number plus one).
        """
        latest = web3.eth.block_number
        return latest + 1

    def __getitem__(self, block_number: int) -> BlockData:
        """
        Fetch a block by its number.

        Arguments
        ---------
        block_number : int
            Block number to look up. Negative values count back from the
            most recently mined block, so `chain[-1]` is the latest block.

        Returns
        -------
        BlockData
            web3 block data object
        """
        if not isinstance(block_number, int):
            raise TypeError("Block height must be given as an integer")
        if block_number < 0:
            block_number += web3.eth.block_number + 1

        block = web3.eth.get_block(block_number)
        timestamp = block["timestamp"]
        if timestamp > self._block_gas_time:
            # opportunistically refresh the cached gas limit from newer blocks
            self._block_gas_limit = block["gasLimit"]
            self._block_gas_time = timestamp
        return block

    def __iter__(self) -> Iterator:
        # the height is read once, up front; blocks are fetched lazily
        heights = range(web3.eth.block_number + 1)
        return (web3.eth.get_block(height) for height in heights)

    def new_blocks(self, height_buffer: int = 0, poll_interval: int = 5) -> Iterator:
        """
        Yield blocks as they appear on the chain.

        Arguments
        ---------
        height_buffer : int, optional
            How many blocks behind "latest" the yielded block sits. A larger
            value delays results but reduces the likelihood of uncles.
        poll_interval : int, optional
            Longest time, in seconds, between block queries while the height
            is unchanged. Lower values detect uncles more frequently.
        """
        if height_buffer < 0:
            raise ValueError("Buffer cannot be negative")

        previous_block = None
        previous_height = 0
        polled_at = 0.0

        while True:
            poll_due = polled_at + poll_interval < time.time()
            if poll_due or previous_height != web3.eth.block_number:
                previous_height = web3.eth.block_number
                candidate = web3.eth.get_block(previous_height - height_buffer)
                polled_at = time.time()
                if candidate != previous_block:
                    previous_block = candidate
                    yield previous_block
            else:
                time.sleep(1)

    @property
    def height(self) -> int:
        """Latest block number reported by the node."""
        return web3.eth.block_number

    @property
    def id(self) -> int:
        """Chain id, queried once and then served from the cache."""
        chain_id = self._chainid
        if chain_id is None:
            chain_id = self._chainid = web3.eth.chain_id
        return chain_id

    @property
    def block_gas_limit(self) -> Wei:
        """Gas limit of a recent block, re-read when the cache is over an hour old."""
        cache_expiry = self._block_gas_time + 3600
        if time.time() > cache_expiry:
            latest = web3.eth.get_block("latest")
            self._block_gas_limit = latest["gasLimit"]
            self._block_gas_time = latest["timestamp"]
        return Wei(self._block_gas_limit)

    @property
    def base_fee(self) -> Wei:
        """`baseFeePerGas` of the latest block."""
        return Wei(web3.eth.get_block("latest").baseFeePerGas)

    @property
    def priority_fee(self) -> Wei:
        """The node's `max_priority_fee` value."""
        return Wei(web3.eth.max_priority_fee)

    def _revert(self, id_: int) -> int:
        """
        Revert to snapshot `id_`, take a fresh snapshot and return its id.

        Registered objects are notified in both branches.
        """
        rpc_client = rpc.Rpc()
        untouched = web3.isConnected() and not web3.eth.block_number and not self._time_offset
        if untouched:
            # height 0 and no time offset: nothing to roll back
            _notify_registry(0)
            return rpc_client.snapshot()

        rpc_client.revert(id_)
        new_id = rpc_client.snapshot()
        try:
            self.sleep(0)
        except NotImplementedError:
            pass
        _notify_registry()
        return new_id

    def _add_to_undo_buffer(self, tx: Any, fn: Any, args: Tuple, kwargs: Dict) -> None:
        """
        Record a confirmed transaction so it can be undone later.

        Waits for `tx` to confirm, stores the pre-transaction snapshot id with
        the call, then takes a new snapshot.
        """
        with self._undo_lock:
            tx._confirmed.wait()
            self._undo_buffer.append((self._current_id, fn, args, kwargs))

            call = (fn, args, kwargs)
            if self._redo_buffer and call == self._redo_buffer[-1]:
                # this call is the replay of the most recent undo
                self._redo_buffer.pop()
            else:
                self._redo_buffer.clear()

            self._current_id = rpc.Rpc().snapshot()
            # ensure the local time offset is correct, in case it was modified by the transaction
            self.sleep(0)

    def _network_connected(self) -> None:
        self._reset_id = None
        try:
            self.reset()
        except NotImplementedError:
            # required for geth dev
            _notify_registry(0)

    def _network_disconnected(self) -> None:
        self._undo_buffer.clear()
        self._redo_buffer.clear()
        self._snapshot_id = None
        self._reset_id = None
        self._current_id = None
        self._chainid = None
        _notify_registry(0)

    def get_transaction(self, txid: Union[str, bytes]) -> TransactionReceipt:
        """
        Return a TransactionReceipt for the given transaction hash.

        A matching receipt already held by `TxHistory` is reused; otherwise a
        new receipt object is created.
        """
        if not isinstance(txid, str):
            txid = HexBytes(txid).hex()
        known = next((receipt for receipt in TxHistory() if receipt.txid == txid), None)
        if known:
            return known
        return TransactionReceipt(txid, silent=True, required_confs=0)

    def time(self) -> int:
        """Return the local epoch time plus the stored time offset, as an int."""
        return int(time.time() + self._time_offset)

    def sleep(self, seconds: int) -> None:
        """
        Advance the time within the test RPC.

        Arguments
        ---------
        seconds : int
            Number of seconds to move forward. Zero only re-reads the offset.
        """
        if not isinstance(seconds, int):
            raise TypeError("seconds must be an integer value")
        self._time_offset = int(rpc.Rpc().sleep(seconds))
        if not seconds:
            return
        self._redo_buffer.clear()
        self._current_id = rpc.Rpc().snapshot()

    def mine(self, blocks: int = 1, timestamp: int = None, timedelta: int = None) -> int:
        """
        Mine one or more blocks within the test RPC.

        Arguments
        ---------
        blocks : int
            Number of new blocks to be mined
        timestamp : int
            Timestamp of the last mined block. With several blocks, their
            timestamps are spread evenly from `chain.time()` to `timestamp`.
        timedelta : int
            Offset for the last mined block, which gets the timestamp
            `chain.time() + timedelta`. Cannot be combined with `timestamp`.

        Returns
        -------
        int
            Current block height
        """
        if not isinstance(blocks, int):
            raise TypeError("`blocks` must be an integer value")
        if timedelta is not None:
            if timestamp is not None:
                raise ValueError("Cannot use both `timestamp` and `timedelta`")
            timestamp = self.time() + timedelta

        # one argument list per block to mine
        if timestamp is None:
            params: List = [[] for _ in range(blocks)]
        elif blocks == 1:
            params = [[timestamp]]
        else:
            start = self.time()
            step = (timestamp - start) / (blocks - 1)
            params = [[round(start + step * index)] for index in range(blocks)]

        for mine_args in params:
            rpc.Rpc().mine(*mine_args)

        if timestamp is not None:
            # explicit timestamps were used, so re-read the time offset
            self.sleep(0)
        self._redo_buffer.clear()
        self._current_id = rpc.Rpc().snapshot()
        return web3.eth.block_number

    def snapshot(self) -> None:
        """
        Take a snapshot of the current state of the EVM.

        The undo and redo buffers are cleared.
        """
        self._undo_buffer.clear()
        self._redo_buffer.clear()
        snapshot_id = rpc.Rpc().snapshot()
        self._snapshot_id = snapshot_id
        self._current_id = snapshot_id

    def revert(self) -> int:
        """
        Revert the EVM to the most recently taken snapshot.

        The undo and redo buffers are cleared.

        Returns
        -------
        int
            Current block height
        """
        if self._snapshot_id is None:
            raise ValueError("No snapshot set")
        self._undo_buffer.clear()
        self._redo_buffer.clear()
        snapshot_id = self._revert(self._snapshot_id)
        self._snapshot_id = snapshot_id
        self._current_id = snapshot_id
        return web3.eth.block_number

    def reset(self) -> int:
        """
        Revert the EVM to the state recorded by the first reset.

        The user snapshot is discarded and the undo and redo buffers are cleared.

        Returns
        -------
        int
            Current block height
        """
        self._snapshot_id = None
        self._undo_buffer.clear()
        self._redo_buffer.clear()
        if self._reset_id is None:
            # first call: record the baseline instead of reverting
            baseline = rpc.Rpc().snapshot()
            self._reset_id = baseline
            self._current_id = baseline
            _notify_registry(0)
        else:
            baseline = self._revert(self._reset_id)
            self._reset_id = baseline
            self._current_id = baseline
        return web3.eth.block_number

    def undo(self, num: int = 1) -> int:
        """
        Undo one or more transactions.

        Arguments
        ---------
        num : int, optional
            Number of transactions to undo.

        Returns
        -------
        int
            Current block height
        """
        with self._undo_lock:
            if num < 1:
                raise ValueError("num must be greater than zero")
            if not self._undo_buffer:
                raise ValueError("Undo buffer is empty")
            available = len(self._undo_buffer)
            if num > available:
                raise ValueError(f"Undo buffer contains {len(self._undo_buffer)} items")

            for _ in range(num):
                id_, fn, args, kwargs = self._undo_buffer.pop()
                self._redo_buffer.append((fn, args, kwargs))

            # a single revert to the oldest popped snapshot undoes them all
            self._current_id = self._revert(id_)
            return web3.eth.block_number

    def redo(self, num: int = 1) -> int:
        """
        Redo one or more undone transactions.

        Arguments
        ---------
        num : int, optional
            Number of transactions to redo.

        Returns
        -------
        int
            Current block height
        """
        with self._undo_lock:
            if num < 1:
                raise ValueError("num must be greater than zero")
            if not self._redo_buffer:
                raise ValueError("Redo buffer is empty")
            available = len(self._redo_buffer)
            if num > available:
                raise ValueError(f"Redo buffer contains {len(self._redo_buffer)} items")

            for _ in range(num):
                fn, args, kwargs = self._redo_buffer.pop()
                fn(*args, **kwargs)

            return web3.eth.block_number
# Objects that must update whenever the RPC is reset or reverted register
# themselves by calling this function. They must also implement `_revert` and
# `_reset` methods in order to receive those notifications.
def _revert_register(obj: object) -> None:
    """
    Subscribe `obj` to revert / reset notifications.

    Only a weak reference is stored in `_revert_refs`, so registering does not
    keep `obj` alive.
    """
    ref = weakref.ref(obj)
    _revert_refs.append(ref)
def _notify_registry(height: int = None) -> None:
    """
    Tell every registered object that the chain was reverted or reset.

    Arguments
    ---------
    height : int, optional
        Block height the chain now sits at. When omitted, the current
        `web3.eth.block_number` is used. A falsy height calls `_reset()` on
        each object, any other value calls `_revert(height)`.
    """
    # presumably run so that unreachable objects are finalized and their weak
    # references are dead before the registry is walked
    gc.collect()
    if height is None:
        height = web3.eth.block_number

    # iterate over a snapshot because dead references are removed on the fly
    for ref in list(_revert_refs):
        target = ref()
        if target is None:
            _revert_refs.remove(ref)
            continue
        if height:
            target._revert(height)
        else:
            target._reset()
def _find_contract(address: Any) -> Any:
if address is None:
return
address = _resolve_address(address)
if address in _contract_map:
return _contract_map[address]
if "chainid" in CONFIG.active_network:
try:
from brownie.network.contract import Contract
return Contract(address)
except (ValueError, CompilerError):
pass
def _get_current_dependencies() -> List:
    """
    Return the sorted, de-duplicated names of all known contracts together
    with every name listed under "dependencies" in their build data.
    """
    names = set()
    for contract in _contract_map.values():
        names.add(contract._name)
        names.update(contract._build.get("dependencies", []))
    return sorted(names)
def _add_contract(contract: Any) -> None:
    """Store `contract` in `_contract_map` under its address, replacing any previous entry."""
    address = contract.address
    _contract_map[address] = contract
def _remove_contract(contract: Any) -> None:
    """Drop `contract` from `_contract_map`; raises KeyError if its address is not registered."""
    _contract_map.pop(contract.address)
def _get_deployment(
address: str = None, alias: str = None
) -> Tuple[Optional[Dict], Optional[Dict]]:
if address and alias:
raise ValueError("Passed both params address and alias, should be only one!")
if address:
address = _resolve_address(address)
query = f"address='{address}'"
elif alias:
query = f"alias='{alias}'"
try:
name = f"chain{CONFIG.active_network['chainid']}"
except KeyError:
raise BrownieEnvironmentError("Functionality not available in local environment") from None
try:
row = cur.fetchone(f"SELECT * FROM {name} WHERE {query}")
except OperationalError:
row = None
if not row:
return None, None
keys = ["address", "alias", "paths"] + DEPLOYMENT_KEYS
build_json = {k: v for k, v in zip(keys, row)}
path_map = build_json.pop("paths")
sources = {
i[1]: cur.fetchone("SELECT source FROM sources WHERE hash=?", (i[0],))[0]
for i in path_map.values()
}
build_json["allSourcePaths"] = {k: v[1] for k, v in path_map.items()}
if isinstance(build_json["pcMap"], dict):
build_json["pcMap"] = dict((int(k), v) for k, v in build_json["pcMap"].items())
return build_json, sources
def _add_deployment(contract: Any, alias: Optional[str] = None) -> None:
    """
    Persist `contract` in the deployments database for the active chain.

    Does nothing when the active network has no "chainid" setting. Every source
    file listed under "allSourcePaths" in the build data is stored in the
    `sources` table under its SHA1 hash, then a row with the address, alias,
    path map and the `DEPLOYMENT_KEYS` build values is written to the
    per-chain table.
    """
    network = CONFIG.active_network
    if "chainid" not in network:
        return

    address = _resolve_address(contract.address)
    name = f"chain{CONFIG.active_network['chainid']}"
    columns = ", ".join(DEPLOYMENT_KEYS)
    cur.execute(
        f"CREATE TABLE IF NOT EXISTS {name} (address UNIQUE, alias UNIQUE, paths, {columns})"
    )

    all_sources = {}
    source_paths = contract._build.get("allSourcePaths", {})
    for key, path in source_paths.items():
        text = contract._sources.get(path)
        if text is None:
            # not held by the contract object: read the file from disk instead
            text = Path(path).read_text()
        digest = sha1(text.encode()).hexdigest()
        cur.insert("sources", digest, text)
        all_sources[key] = [digest, path]

    build_values = [contract._build.get(field) for field in DEPLOYMENT_KEYS]
    cur.insert(name, address, alias, all_sources, *build_values)
def _remove_deployment(
address: str = None, alias: str = None
) -> Tuple[Optional[Dict], Optional[Dict]]:
if address and alias:
raise ValueError("Passed both params address and alias, should be only one!")
if address:
address = _resolve_address(address)
query = f"address='{address}'"
elif alias:
query = f"alias='{alias}'"
try:
name = f"chain{CONFIG.active_network['chainid']}"
except KeyError:
raise BrownieEnvironmentError("Functionality not available in local environment") from None
build_json, sources = _get_deployment(address, alias)
# delete entry from chain{n}
cur.execute(f"DELETE FROM {name} WHERE {query}")
# delete all entries from sources matching the contract's source hashes
contract = _find_contract(address)
if contract:
for key, path in contract._build.get("allSourcePaths", {}).items():
source = contract._sources.get(path)
if source is None:
source = Path(path).read_text()
hash_ = sha1(source.encode()).hexdigest()
cur.execute(f"DELETE FROM sources WHERE hash='{hash_}'")
return build_json, sources