Skip to content
This repository was archived by the owner on Mar 11, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions stake-pool/program/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -942,9 +942,13 @@ impl Fee {
if self.denominator == 0 {
return Some(0);
}
(amt as u128)
.checked_mul(self.numerator as u128)?
.checked_div(self.denominator as u128)
let numerator = (amt as u128).checked_mul(self.numerator as u128)?;
// ceiling the calculation by adding (denominator - 1) to the numerator
let denominator = self.denominator as u128;
numerator
.checked_add(denominator)?
.checked_sub(1)?
.checked_div(denominator)
}

/// Withdrawal fees have some additional restrictions,
Expand Down
11 changes: 7 additions & 4 deletions stake-pool/program/tests/helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -897,15 +897,17 @@ impl StakePoolAccounts {
}

pub fn calculate_fee(&self, amount: u64) -> u64 {
amount * self.epoch_fee.numerator / self.epoch_fee.denominator
(amount * self.epoch_fee.numerator + self.epoch_fee.denominator - 1)
/ self.epoch_fee.denominator
}

pub fn calculate_withdrawal_fee(&self, pool_tokens: u64) -> u64 {
pool_tokens * self.withdrawal_fee.numerator / self.withdrawal_fee.denominator
(pool_tokens * self.withdrawal_fee.numerator + self.withdrawal_fee.denominator - 1)
/ self.withdrawal_fee.denominator
}

pub fn calculate_inverse_withdrawal_fee(&self, pool_tokens: u64) -> u64 {
pool_tokens * self.withdrawal_fee.denominator
(pool_tokens * self.withdrawal_fee.denominator + self.withdrawal_fee.denominator - 1)
/ (self.withdrawal_fee.denominator - self.withdrawal_fee.numerator)
}

Expand All @@ -914,7 +916,8 @@ impl StakePoolAccounts {
}

pub fn calculate_sol_deposit_fee(&self, pool_tokens: u64) -> u64 {
pool_tokens * self.sol_deposit_fee.numerator / self.sol_deposit_fee.denominator
(pool_tokens * self.sol_deposit_fee.numerator + self.sol_deposit_fee.denominator - 1)
/ self.sol_deposit_fee.denominator
}

pub fn calculate_sol_referral_fee(&self, deposit_fee_collected: u64) -> u64 {
Expand Down
4 changes: 2 additions & 2 deletions stake-pool/program/tests/withdraw_sol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,15 +214,15 @@ async fn fail_overdraw_reserve() {
.await;
assert!(error.is_none(), "{:?}", error);

// try to withdraw one lamport, will overdraw
// try to withdraw one lamport after fees, will overdraw
let error = stake_pool_accounts
.withdraw_sol(
&mut context.banks_client,
&context.payer,
&context.last_blockhash,
&user,
&pool_token_account,
1,
2,
None,
)
.await
Expand Down
18 changes: 3 additions & 15 deletions stake-pool/program/tests/withdraw_with_fee.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ mod helpers;
use {
bincode::deserialize,
helpers::*,
solana_program::{borsh0_10::try_from_slice_unchecked, pubkey::Pubkey, stake},
solana_program::{pubkey::Pubkey, stake},
solana_program_test::*,
solana_sdk::signature::{Keypair, Signer},
spl_stake_pool::{minimum_stake_lamports, state},
spl_stake_pool::minimum_stake_lamports,
};

#[tokio::test]
Expand Down Expand Up @@ -183,20 +183,8 @@ async fn success_empty_out_stake_with_fee() {
.await;
let lamports_to_withdraw =
validator_stake_account.lamports - minimum_stake_lamports(&meta, stake_minimum_delegation);
let stake_pool_account = get_account(
&mut context.banks_client,
&stake_pool_accounts.stake_pool.pubkey(),
)
.await;
let stake_pool =
try_from_slice_unchecked::<state::StakePool>(stake_pool_account.data.as_slice()).unwrap();
let fee = stake_pool.stake_withdrawal_fee;
let inverse_fee = state::Fee {
numerator: fee.denominator - fee.numerator,
denominator: fee.denominator,
};
let pool_tokens_to_withdraw =
lamports_to_withdraw * inverse_fee.denominator / inverse_fee.numerator;
stake_pool_accounts.calculate_inverse_withdrawal_fee(lamports_to_withdraw);

let last_blockhash = context
.banks_client
Expand Down