RefundService.java 9.43 KB
package com.ecommerce.payment.service;

import com.ecommerce.payment.model.Payment;
import com.ecommerce.payment.model.Refund;
import com.ecommerce.payment.model.dto.RefundRequest;
import com.ecommerce.payment.repository.PaymentRepository;
import com.ecommerce.payment.repository.RefundRepository;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.cache.annotation.CacheEvict;
import org.springframework.cache.annotation.Cacheable;
import org.springframework.cache.annotation.Caching;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

import java.math.BigDecimal;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;

@Slf4j
@Service
@RequiredArgsConstructor
public class RefundService {

    // Status strings stored on Refund and Payment records by this service.
    private static final String STATUS_PENDING = "PENDING";
    private static final String STATUS_PROCESSING = "PROCESSING";
    private static final String STATUS_SUCCEEDED = "SUCCEEDED";
    private static final String STATUS_FAILED = "FAILED";
    private static final String STATUS_CANCELLED = "CANCELLED";
    private static final String STATUS_REFUNDED = "REFUNDED";

    private final RefundRepository refundRepository;
    private final PaymentRepository paymentRepository;
    private final StripeService stripeService;
    private final PayPalService payPalService;
    private final RabbitMQService rabbitMQService;

    /**
     * Creates a refund for an existing payment and immediately processes it through the payment's
     * gateway.
     *
     * <p>The refund is saved as {@code PENDING} first, then processing updates it to
     * {@code SUCCEEDED} or {@code FAILED}. Evicts every cached refund list and the cached entry for
     * the new refund id.
     *
     * @param request the payment id, amount, currency and reason for the refund
     * @return the processed refund as a response map, plus a {@code success} flag
     * @throws RuntimeException if the payment does not exist, the amount is not refundable, or
     *     gateway processing throws
     */
    @Transactional
    @Caching(evict = {
        @CacheEvict(value = "refunds", allEntries = true),
        // The result is a Map, so the id is read with index syntax (a map lookup); property syntax
        // (#result.refundId) would look for a getRefundId() method on the map object.
        @CacheEvict(value = "refund", key = "#result['refundId']")
    })
    public Map<String, Object> createRefund(RefundRequest request) {
        Payment payment = paymentRepository.findByPaymentId(request.getPaymentId())
                .orElseThrow(() -> new RuntimeException("Payment not found: " + request.getPaymentId()));

        validateRefundAmount(payment, request.getAmount());

        Refund refund = new Refund();
        refund.setPayment(payment);
        refund.setAmount(request.getAmount());
        refund.setCurrency(request.getCurrency());
        refund.setReason(request.getReason());
        refund.setStatus(STATUS_PENDING);

        Refund savedRefund = refundRepository.save(refund);

        Map<String, Object> result = processRefund(savedRefund);

        log.info("Refund created: {}", savedRefund.getRefundId());
        return result;
    }

    /**
     * Returns one refund by its public refund id. Results are cached in {@code "refund"}.
     *
     * @param refundId the refund's public identifier
     * @return the refund as a response map
     * @throws RuntimeException if no refund has that id
     */
    // Read-only transaction keeps the persistence context open while the response is mapped, because
    // mapping navigates refund.getPayment() (which may be lazily loaded — confirm entity mapping).
    @Transactional(readOnly = true)
    @Cacheable(value = "refund", key = "#refundId")
    public Map<String, Object> getRefund(String refundId) {
        Refund refund = refundRepository.findByRefundId(refundId)
                .orElseThrow(() -> new RuntimeException("Refund not found: " + refundId));
        return mapRefundToResponse(refund);
    }

    /**
     * Lists all refunds recorded against a payment.
     *
     * <p>Cached in {@code "refunds"}. Keys in that cache are prefixed per query type because three
     * different list queries share it; unprefixed keys could collide and return a {@code List} where
     * a {@code Page} is expected.
     *
     * @param paymentId the payment's public identifier
     * @return the payment's refunds as response maps
     * @throws RuntimeException if the payment does not exist
     */
    @Transactional(readOnly = true)
    @Cacheable(value = "refunds", key = "'payment:' + #paymentId")
    public List<Map<String, Object>> getRefundsByPayment(String paymentId) {
        Payment payment = paymentRepository.findByPaymentId(paymentId)
                .orElseThrow(() -> new RuntimeException("Payment not found: " + paymentId));

        List<Refund> refunds = refundRepository.findByPaymentId(payment.getId());
        return refunds.stream()
                .map(this::mapRefundToResponse)
                .collect(Collectors.toList());
    }

    /**
     * Returns one page of refunds that have the given status.
     *
     * <p>The cache key includes page number, page size and sort. Keying on the page number alone
     * served a page of the wrong size or order when the same page number was requested with
     * different parameters.
     *
     * @param status the refund status to filter by
     * @param pageable the page to fetch
     * @return the page of refunds as response maps
     */
    @Transactional(readOnly = true)
    @Cacheable(value = "refunds",
            key = "'status:' + #status + ':' + #pageable.pageNumber + ':' + #pageable.pageSize + ':' + #pageable.sort")
    public Page<Map<String, Object>> getRefundsByStatus(String status, Pageable pageable) {
        return refundRepository.findByStatus(status, pageable)
                .map(this::mapRefundToResponse);
    }

    /**
     * Returns one page of all refunds.
     *
     * @param pageable the page to fetch
     * @return the page of refunds as response maps
     */
    @Transactional(readOnly = true)
    @Cacheable(value = "refunds",
            key = "'all:' + #pageable.pageNumber + ':' + #pageable.pageSize + ':' + #pageable.sort")
    public Page<Map<String, Object>> getAllRefunds(Pageable pageable) {
        return refundRepository.findAll(pageable)
                .map(this::mapRefundToResponse);
    }

    /**
     * Cancels a refund that is still {@code PENDING} or {@code PROCESSING}.
     *
     * <p>The cancellation reason is stored in the refund's failure-reason field. Evicts all cached
     * refund lists and the cached entry for this refund.
     *
     * @param refundId the refund's public identifier
     * @param reason why the refund is being cancelled
     * @return the cancelled refund as a response map
     * @throws RuntimeException if the refund does not exist or is in any other status
     */
    @Transactional
    @Caching(evict = {
        @CacheEvict(value = "refunds", allEntries = true),
        @CacheEvict(value = "refund", key = "#refundId")
    })
    public Map<String, Object> cancelRefund(String refundId, String reason) {
        Refund refund = refundRepository.findByRefundId(refundId)
                .orElseThrow(() -> new RuntimeException("Refund not found: " + refundId));

        String status = refund.getStatus();
        if (!STATUS_PENDING.equals(status) && !STATUS_PROCESSING.equals(status)) {
            throw new RuntimeException("Refund cannot be cancelled in current status: " + status);
        }

        // NOTE(review): no gateway call is made here; a PROCESSING refund may already be in flight at
        // the gateway — confirm cancellation is reconciled elsewhere.
        refund.setStatus(STATUS_CANCELLED);
        refund.setFailureReason(reason);

        Refund cancelledRefund = refundRepository.save(refund);

        log.info("Refund cancelled: {}", refundId);
        return mapRefundToResponse(cancelledRefund);
    }

    /**
     * Reports the total refunded amount within a date range.
     *
     * @param startDate start of the range, passed straight to the repository query
     * @param endDate end of the range, passed straight to the repository query
     * @return a map with {@code totalRefunds} (never null), {@code startDate} and {@code endDate}
     */
    @Transactional(readOnly = true)
    public Map<String, Object> getRefundStatistics(LocalDateTime startDate, LocalDateTime endDate) {
        BigDecimal totalRefunds = refundRepository.getTotalRefundsByDateRange(startDate, endDate);

        Map<String, Object> stats = new HashMap<>();
        // The aggregate can be null when no refunds match (e.g. SQL SUM over no rows); report zero.
        stats.put("totalRefunds", totalRefunds != null ? totalRefunds : BigDecimal.ZERO);
        stats.put("startDate", startDate);
        stats.put("endDate", endDate);

        return stats;
    }

    /**
     * Rejects a refund request that the payment cannot cover.
     *
     * @throws RuntimeException if the payment has not succeeded, the amount is missing or not
     *     positive, or the amount exceeds what is still refundable
     */
    private void validateRefundAmount(Payment payment, BigDecimal refundAmount) {
        if (!STATUS_SUCCEEDED.equals(payment.getStatus())) {
            throw new RuntimeException("Cannot refund payment that is not succeeded");
        }

        // A missing amount gets the same message as a non-positive one instead of an NPE.
        if (refundAmount == null || refundAmount.compareTo(BigDecimal.ZERO) <= 0) {
            throw new RuntimeException("Refund amount must be greater than 0");
        }

        if (refundAmount.compareTo(payment.getAmount()) > 0) {
            throw new RuntimeException("Refund amount cannot exceed payment amount");
        }

        // Refunds that succeeded or are still in flight (PENDING/PROCESSING) both reserve money.
        // Counting only SUCCEEDED refunds let a second request over-refund while the first was
        // still outstanding.
        BigDecimal committed = sumRefunds(payment, null, true);
        BigDecimal remainingAmount = payment.getAmount().subtract(committed);
        if (refundAmount.compareTo(remainingAmount) > 0) {
            throw new RuntimeException("Refund amount exceeds remaining refundable amount: " + remainingAmount);
        }
    }

    /**
     * Sums the amounts of the payment's refunds, skipping {@code exclude}.
     *
     * @param payment the payment whose refunds are summed
     * @param exclude a refund to leave out of the sum, or {@code null}
     * @param includeInFlight whether PENDING and PROCESSING refunds count alongside SUCCEEDED ones
     * @return the summed amount; zero when nothing qualifies
     */
    private BigDecimal sumRefunds(Payment payment, Refund exclude, boolean includeInFlight) {
        return refundRepository.findByPaymentId(payment.getId()).stream()
                .filter(r -> !isSameRefund(r, exclude))
                .filter(r -> STATUS_SUCCEEDED.equals(r.getStatus())
                        || (includeInFlight && (STATUS_PENDING.equals(r.getStatus())
                                || STATUS_PROCESSING.equals(r.getStatus()))))
                .map(Refund::getAmount)
                .reduce(BigDecimal.ZERO, BigDecimal::add);
    }

    /** Returns whether {@code candidate} is {@code other}, by identity or by matching non-null refund id. */
    private static boolean isSameRefund(Refund candidate, Refund other) {
        if (other == null) {
            return false;
        }
        if (candidate == other) {
            return true;
        }
        String otherId = other.getRefundId();
        return otherId != null && otherId.equals(candidate.getRefundId());
    }

    /**
     * Submits a saved refund to its payment's gateway and records the outcome.
     *
     * <p>On success:
     * <ul>
     *   <li>the refund becomes {@code SUCCEEDED};</li>
     *   <li>the payment is marked {@code REFUNDED} once its succeeded refunds, including this one,
     *       cover the full payment amount;</li>
     *   <li>a success event is published.</li>
     * </ul>
     * When the gateway reports failure, the refund becomes {@code FAILED}. If any exception is
     * thrown, the refund is marked {@code FAILED} and the exception is rethrown wrapped, with the
     * cause preserved.
     *
     * @param refund the saved refund to process
     * @return the refund as a response map plus a {@code success} flag
     * @throws RuntimeException if gateway dispatch or persistence throws
     */
    private Map<String, Object> processRefund(Refund refund) {
        try {
            Payment payment = refund.getPayment();
            Map<String, Object> gatewayResult = submitToGateway(payment, refund);

            // A missing or non-boolean flag counts as failure, rather than an NPE/ClassCastException.
            boolean success = Boolean.TRUE.equals(gatewayResult.get("success"));
            String gatewayRefundId = (String) gatewayResult.get("gatewayRefundId");

            if (success) {
                LocalDateTime now = LocalDateTime.now();
                refund.setStatus(STATUS_SUCCEEDED);
                refund.setGatewayRefundId(gatewayRefundId);
                refund.setProcessedAt(now);

                // Include earlier succeeded refunds, so a payment refunded in several parts is also
                // marked REFUNDED once the parts add up to the full amount.
                BigDecimal totalRefunded = sumRefunds(payment, refund, false).add(refund.getAmount());
                if (totalRefunded.compareTo(payment.getAmount()) >= 0) {
                    payment.setStatus(STATUS_REFUNDED);
                    payment.setRefundedAt(now);
                    paymentRepository.save(payment);
                }

                rabbitMQService.sendRefundSuccessEvent(refund);
            } else {
                refund.setStatus(STATUS_FAILED);
                refund.setFailureReason("Refund processing failed");
            }

            Refund processedRefund = refundRepository.save(refund);

            Map<String, Object> response = mapRefundToResponse(processedRefund);
            response.put("success", success);
            return response;
        } catch (Exception e) {
            refund.setStatus(STATUS_FAILED);
            refund.setFailureReason(e.getMessage());
            refundRepository.save(refund);

            // NOTE(review): this runs inside createRefund's @Transactional. Under Spring's default
            // rollback rules, an unchecked rethrow rolls back this FAILED save and the PENDING insert.
            // If the gateway had already refunded (e.g. the event publish failed), no local record
            // remains — confirm the intended failure semantics.
            throw new RuntimeException("Refund processing failed: " + e.getMessage(), e);
        }
    }

    /**
     * Sends the refund to the service for the payment's gateway.
     *
     * @return the gateway result map; read for its {@code success} and {@code gatewayRefundId} keys
     * @throws IllegalStateException if the payment has no gateway set
     */
    private Map<String, Object> submitToGateway(Payment payment, Refund refund) {
        String gateway = payment.getPaymentGateway();
        if (gateway == null) {
            throw new IllegalStateException("Payment has no payment gateway");
        }
        // Locale.ROOT keeps the match stable under locales with special casing (e.g. Turkish 'i').
        switch (gateway.toUpperCase(Locale.ROOT)) {
            case "STRIPE":
                return stripeService.createRefund(refund);
            case "PAYPAL":
                return payPalService.createRefund(refund);
            default:
                return simulateRefund();
        }
    }

    /**
     * Produces a simulated gateway result for payment methods without an integration.
     *
     * <p>NOTE(review): this randomly fails about 5% of the time and moves no money. The generated id
     * is only millisecond-unique. Confirm it is acceptable outside test environments.
     */
    private static Map<String, Object> simulateRefund() {
        Map<String, Object> result = new HashMap<>();
        result.put("success", Math.random() > 0.05); // 95% success rate
        result.put("gatewayRefundId", "REF_" + System.currentTimeMillis());
        return result;
    }

    /**
     * Flattens a refund into the response map returned by this service.
     *
     * @param refund the refund to map; its payment must be non-null
     * @return a new mutable map, so callers can add keys such as {@code success}
     */
    private Map<String, Object> mapRefundToResponse(Refund refund) {
        Map<String, Object> response = new HashMap<>();
        response.put("refundId", refund.getRefundId());
        response.put("paymentId", refund.getPayment().getPaymentId());
        response.put("amount", refund.getAmount());
        response.put("currency", refund.getCurrency());
        response.put("status", refund.getStatus());
        response.put("reason", refund.getReason());
        response.put("gatewayRefundId", refund.getGatewayRefundId());
        response.put("failureReason", refund.getFailureReason());
        response.put("failureCode", refund.getFailureCode());
        response.put("createdAt", refund.getCreatedAt());
        response.put("processedAt", refund.getProcessedAt());
        return response;
    }
}