Skip to content

Commit 2ae2a71

Browse files
committed
Salesforce fixes
1 parent 349691f commit 2ae2a71

File tree

1 file changed

+84
-27
lines changed

1 file changed

+84
-27
lines changed

examples/salesforce/python/salesforce_endpoints.py

Lines changed: 84 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,42 +5,53 @@
55
This is a simpler alternative to the plugin-based approach.
66
"""
77

8-
from typing import Dict, Any, List, Optional
8+
from typing import Dict, Any, List, Optional, Callable
99
import logging
10-
import json
10+
import time
11+
import functools
12+
import threading
1113
import simple_salesforce
14+
from simple_salesforce.exceptions import SalesforceExpiredSession
15+
1216
from mxcp.runtime import config, on_init, on_shutdown
1317

1418
logger = logging.getLogger(__name__)
1519

1620
# Global Salesforce client for reuse across all function calls
1721
sf_client: Optional[simple_salesforce.Salesforce] = None
22+
# Thread lock to protect client initialization
23+
_client_lock = threading.Lock()
1824

1925

2026
@on_init
2127
def setup_salesforce_client():
22-
"""Initialize Salesforce client when server starts."""
23-
global sf_client
24-
logger.info("Initializing Salesforce client...")
25-
26-
sf_config = config.get_secret("salesforce")
27-
if not sf_config:
28-
raise ValueError("Salesforce configuration not found. Please configure Salesforce secrets in your user config.")
29-
30-
required_keys = ["username", "password", "security_token", "instance_url", "client_id"]
31-
missing_keys = [key for key in required_keys if not sf_config.get(key)]
32-
if missing_keys:
33-
raise ValueError(f"Missing Salesforce configuration keys: {', '.join(missing_keys)}")
28+
"""Initialize Salesforce client when server starts.
3429
35-
sf_client = simple_salesforce.Salesforce(
36-
username=sf_config["username"],
37-
password=sf_config["password"],
38-
security_token=sf_config["security_token"],
39-
instance_url=sf_config["instance_url"],
40-
client_id=sf_config["client_id"]
41-
)
30+
Thread-safe: multiple threads can safely call this simultaneously.
31+
"""
32+
global sf_client
4233

43-
logger.info("Salesforce client initialized successfully")
34+
with _client_lock:
35+
logger.info("Initializing Salesforce client...")
36+
37+
sf_config = config.get_secret("salesforce")
38+
if not sf_config:
39+
raise ValueError("Salesforce configuration not found. Please configure Salesforce secrets in your user config.")
40+
41+
required_keys = ["username", "password", "security_token", "instance_url", "client_id"]
42+
missing_keys = [key for key in required_keys if not sf_config.get(key)]
43+
if missing_keys:
44+
raise ValueError(f"Missing Salesforce configuration keys: {', '.join(missing_keys)}")
45+
46+
sf_client = simple_salesforce.Salesforce(
47+
username=sf_config["username"],
48+
password=sf_config["password"],
49+
security_token=sf_config["security_token"],
50+
instance_url=sf_config["instance_url"],
51+
client_id=sf_config["client_id"]
52+
)
53+
54+
logger.info("Salesforce client initialized successfully")
4455

4556

4657
@on_shutdown
@@ -53,13 +64,55 @@ def cleanup_salesforce_client():
5364
logger.info("Salesforce client cleaned up")
5465

5566

67+
def retry_on_session_expiration(func: Callable) -> Callable:
68+
"""
69+
Decorator that automatically retries functions on session expiration.
70+
71+
This only retries on SalesforceExpiredSession, not SalesforceAuthenticationFailed.
72+
Authentication failures (wrong credentials) should not be retried.
73+
74+
Retries up to 2 times on session expiration (3 total attempts).
75+
Thread-safe: setup_salesforce_client() handles concurrent access internally.
76+
77+
Usage:
78+
@retry_on_session_expiration
79+
def my_salesforce_function():
80+
# Function that might fail due to session expiration
81+
pass
82+
"""
83+
@functools.wraps(func)
84+
def wrapper(*args, **kwargs):
85+
max_retries = 2 # Hardcoded: 2 retries = 3 total attempts
86+
87+
for attempt in range(max_retries + 1):
88+
try:
89+
return func(*args, **kwargs)
90+
except SalesforceExpiredSession as e:
91+
if attempt < max_retries:
92+
logger.warning(f"Session expired on attempt {attempt + 1} in {func.__name__}: {e}")
93+
logger.info(f"Retrying after re-initializing client (attempt {attempt + 2}/{max_retries + 1})")
94+
95+
try:
96+
setup_salesforce_client() # Thread-safe internally
97+
time.sleep(0.1) # Small delay to avoid immediate retry
98+
except Exception as setup_error:
99+
logger.error(f"Failed to re-initialize Salesforce client: {setup_error}")
100+
raise setup_error # Raise the setup error, not the original session error
101+
else:
102+
# Last attempt failed, re-raise the session expiration error
103+
raise e
104+
105+
return wrapper
106+
107+
56108
def _get_salesforce_client() -> simple_salesforce.Salesforce:
57109
"""Get the global Salesforce client."""
58110
if sf_client is None:
59111
raise RuntimeError("Salesforce client not initialized. Make sure the server is started properly.")
60112
return sf_client
61113

62114

115+
@retry_on_session_expiration
63116
def soql(query: str) -> List[Dict[str, Any]]:
64117
"""Execute an SOQL query against Salesforce.
65118
@@ -84,6 +137,7 @@ def soql(query: str) -> List[Dict[str, Any]]:
84137
]
85138

86139

140+
@retry_on_session_expiration
87141
def sosl(query: str) -> List[Dict[str, Any]]:
88142
"""Execute a SOSL query against Salesforce.
89143
@@ -105,6 +159,7 @@ def sosl(query: str) -> List[Dict[str, Any]]:
105159
return result.get('searchRecords', [])
106160

107161

162+
@retry_on_session_expiration
108163
def search(search_term: str) -> List[Dict[str, Any]]:
109164
"""Search across all Salesforce objects using a simple search term.
110165
@@ -125,6 +180,7 @@ def search(search_term: str) -> List[Dict[str, Any]]:
125180
return sosl(sosl_query)
126181

127182

183+
@retry_on_session_expiration
128184
def list_sobjects(filter: Optional[str] = None) -> List[str]:
129185
"""List all available Salesforce objects (sObjects) in the org.
130186
@@ -136,23 +192,22 @@ def list_sobjects(filter: Optional[str] = None) -> List[str]:
136192
list: List of Salesforce object names as strings
137193
"""
138194
sf = _get_salesforce_client()
139-
140195
describe_result = sf.describe()
141-
196+
142197
object_names = [obj['name'] for obj in describe_result['sobjects']]
143-
198+
144199
if filter is not None and filter.strip():
145200
filter_lower = filter.lower()
146201
object_names = [
147202
name for name in object_names
148203
if filter_lower in name.lower()
149204
]
150-
205+
151206
object_names.sort()
152-
153207
return object_names
154208

155209

210+
@retry_on_session_expiration
156211
def describe_sobject(sobject_name: str) -> Dict[str, Any]:
157212
"""Get the description of a Salesforce object type.
158213
@@ -195,6 +250,8 @@ def describe_sobject(sobject_name: str) -> Dict[str, Any]:
195250

196251
return fields_info
197252

253+
254+
@retry_on_session_expiration
198255
def get_sobject(sobject_name: str, record_id: str) -> Dict[str, Any]:
199256
"""Get a specific Salesforce object by its ID.
200257

0 commit comments

Comments
 (0)